| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- import json
- import logging
- import re
- import secrets
- import string
- import struct
- import subprocess
- import time
- import uuid
- from collections.abc import Generator, Mapping
- from datetime import datetime
- from hashlib import sha256
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
- from zoneinfo import available_timezones
-
- from flask import Response, stream_with_context
- from flask_restx import fields
- from pydantic import BaseModel
-
- from configs import dify_config
- from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
- from core.file import helpers as file_helpers
- from core.model_runtime.utils.encoders import jsonable_encoder
- from extensions.ext_redis import redis_client
-
- if TYPE_CHECKING:
- from models.account import Account
- from models.model import EndUser
-
- logger = logging.getLogger(__name__)
-
-
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
    """
    Return the tenant identifier carried by a user object.

    Args:
        user: An ``Account`` or ``EndUser`` instance.

    Returns:
        ``user.current_tenant_id`` for an ``Account`` and ``user.tenant_id``
        for an ``EndUser``; the annotation allows either to be ``None``.

    Raises:
        ValueError: If ``user`` is neither an ``Account`` nor an ``EndUser``.
    """
    # Imported at call time: at module level these names are only imported
    # under TYPE_CHECKING, so they are not available for isinstance() otherwise.
    from models.account import Account
    from models.model import EndUser

    if isinstance(user, Account):
        return user.current_tenant_id
    if isinstance(user, EndUser):
        return user.tenant_id
    raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
-
-
def run(script):
    """
    Run ``script`` through the shell after sourcing ``/root/.bashrc``.

    Args:
        script: Shell command text appended verbatim after the ``source`` prefix.

    Returns:
        The ``(exitcode, output)`` tuple returned by ``subprocess.getstatusoutput``.
    """
    # NOTE(review): ``script`` is concatenated into a shell command line without
    # any escaping, so only trusted input should ever reach this function.
    command = "source /root/.bashrc && " + script
    return subprocess.getstatusoutput(command)
-
-
class AppIconUrlField(fields.Raw):
    """Response field that turns an image-type app icon into a signed file URL."""

    def output(self, key, obj, **kwargs):
        """
        Return a signed URL for the icon of ``obj``, or ``None``.

        ``obj`` may be an ``App``/``Site`` instance, or a dict holding one under
        the ``"app"`` key. ``None`` is returned when ``obj`` is ``None``, is of
        any other type, or its ``icon_type`` is not ``IconType.IMAGE``.
        """
        if obj is None:
            return None

        # Imported lazily, inside the method, as in the rest of this module.
        from models.model import App, IconType, Site

        target = obj
        if isinstance(target, dict) and "app" in target:
            target = target["app"]

        has_image_icon = isinstance(target, App | Site) and target.icon_type == IconType.IMAGE.value
        if not has_image_icon:
            return None
        return file_helpers.get_signed_file_url(target.icon)
-
-
class AvatarUrlField(fields.Raw):
    """Response field that turns an account avatar into a signed file URL."""

    def output(self, key, obj, **kwargs):
        """
        Return a signed URL for ``obj.avatar``, or ``None``.

        ``None`` is returned when ``obj`` is ``None``, is not an ``Account``,
        or has no avatar set.
        """
        if obj is None:
            return None

        # Imported lazily, inside the method, as in the rest of this module.
        from models.account import Account

        if not isinstance(obj, Account) or obj.avatar is None:
            return None
        return file_helpers.get_signed_file_url(obj.avatar)
-
-
class TimestampField(fields.Raw):
    """Response field that renders a value exposing ``timestamp()`` as an integer."""

    def format(self, value) -> int:
        """Return ``value.timestamp()`` truncated to ``int``."""
        epoch_seconds = value.timestamp()
        return int(epoch_seconds)
-
-
def email(email):
    """
    Validate that ``email`` has the shape of an email address.

    Args:
        email: The candidate address (a string).

    Returns:
        The input unchanged when it matches the pattern.

    Raises:
        ValueError: If the input does not match the pattern.
    """
    # Local part: word characters plus the listed punctuation; domain: one or
    # more dot-terminated labels followed by a final label of 2+ characters.
    pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"

    # fullmatch() is required here: with match(), "$" also succeeds just before
    # a trailing newline, so "user@example.com\n" would be accepted.
    if re.fullmatch(pattern, email) is not None:
        return email

    error = f"{email} is not a valid email."
    raise ValueError(error)
-
-
def uuid_value(value):
    """
    Validate and normalise a UUID string.

    Args:
        value: A string accepted by ``uuid.UUID``, or the empty string.

    Returns:
        The empty string when ``value == ""``; otherwise the canonical
        hyphenated string form of the parsed UUID.

    Raises:
        ValueError: If ``value`` cannot be parsed as a UUID. This includes
            non-string input, which ``uuid.UUID`` rejects with ``TypeError``
            or ``AttributeError`` rather than ``ValueError``.
    """
    if value == "":
        return str(value)

    try:
        return str(uuid.UUID(value))
    except (AttributeError, TypeError, ValueError) as exc:
        # Report every rejection as ValueError so callers see a single
        # validation-failure type, as with the other validators in this module.
        error = f"{value} is not a valid uuid."
        raise ValueError(error) from exc
-
-
def alphanumeric(value: str):
    """
    Validate that ``value`` consists only of ASCII letters, digits and underscores.

    Args:
        value: The string to check.

    Returns:
        ``value`` unchanged when it is non-empty and fully matches ``[a-zA-Z0-9_]+``.

    Raises:
        ValueError: If ``value`` is empty or contains any other character.
    """
    # fullmatch() rather than match(): with match(), "$" also succeeds just
    # before a trailing newline, so "abc\n" would be accepted.
    if re.fullmatch(r"^[a-zA-Z0-9_]+$", value):
        return value

    raise ValueError(f"{value} is not a valid alphanumeric value")
-
-
def timestamp_value(timestamp):
    """
    Validate a non-negative integer timestamp.

    Args:
        timestamp: Any value accepted by ``int()``.

    Returns:
        The value converted to ``int``.

    Raises:
        ValueError: If the value cannot be converted to ``int`` or is negative.
            Conversion failures that ``int()`` reports as ``TypeError``
            (e.g. ``None``) or ``OverflowError`` (e.g. infinity) are reported
            as ``ValueError`` too.
    """
    try:
        int_timestamp = int(timestamp)
    except (TypeError, ValueError, OverflowError) as exc:
        error = f"{timestamp} is not a valid timestamp."
        raise ValueError(error) from exc

    if int_timestamp < 0:
        error = f"{timestamp} is not a valid timestamp."
        raise ValueError(error)

    return int_timestamp
-
-
class StrLen:
    """Validator that rejects values whose length exceeds ``max_length``."""

    def __init__(self, max_length, argument="argument"):
        """
        Args:
            max_length: Largest length (inclusive) that is accepted.
            argument: Name of the validated argument, used in the error message.
        """
        self.max_length = max_length
        self.argument = argument

    def __call__(self, value):
        """Return ``value`` unchanged, or raise ``ValueError`` if it is too long."""
        too_long = len(value) > self.max_length
        if too_long:
            template = "Invalid {arg}: {val}. {arg} cannot exceed length {length}"
            raise ValueError(template.format(arg=self.argument, val=value, length=self.max_length))

        return value
-
-
class DatetimeString:
    """Validator that accepts strings parseable with a given ``strptime`` format."""

    def __init__(self, format, argument="argument"):
        """
        Args:
            format: ``datetime.strptime`` format string the value must follow.
            argument: Name of the validated argument, used in the error message.
        """
        self.format = format
        self.argument = argument

    def __call__(self, value):
        """Return ``value`` unchanged, or raise ``ValueError`` if it does not parse."""
        try:
            # Only the parse outcome matters; the parsed datetime is discarded.
            datetime.strptime(value, self.format)
        except ValueError:
            template = "Invalid {arg}: {val}. {arg} must be conform to the format {format}"
            raise ValueError(template.format(arg=self.argument, val=value, format=self.format))
        else:
            return value
-
-
def _get_float(value):
    """
    Convert ``value`` to ``float``.

    Raises:
        ValueError: If ``float(value)`` fails with ``TypeError`` or ``ValueError``.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError):
        raise ValueError(f"{value} is not a valid float")
    return converted
-
-
def timezone(timezone_string):
    """
    Validate a timezone name.

    Args:
        timezone_string: Candidate timezone key.

    Returns:
        ``timezone_string`` unchanged when it is truthy and present in
        ``zoneinfo.available_timezones()``.

    Raises:
        ValueError: If the value is falsy or not a known timezone key.
    """
    if not timezone_string or timezone_string not in available_timezones():
        error = f"{timezone_string} is not a valid timezone."
        raise ValueError(error)

    return timezone_string
-
-
def generate_string(n):
    """
    Generate a random string of ASCII letters and digits.

    Characters are drawn with ``secrets.choice``.

    Args:
        n: Number of characters to generate; ``range(n)`` yields nothing for
            ``n <= 0``, so the result is then the empty string.

    Returns:
        A string of length ``max(n, 0)`` over ``[A-Za-z0-9]``.
    """
    letters_digits = string.ascii_letters + string.digits
    # join() over a generator replaces repeated ``+=`` concatenation in a loop.
    return "".join(secrets.choice(letters_digits) for _ in range(n))
-
-
def extract_remote_ip(request) -> str:
    """
    Determine the client IP address for ``request``.

    Lookup order:
        1. the ``CF-Connecting-IP`` header, if present and non-empty;
        2. the first address listed in the first ``X-Forwarded-For`` header;
        3. ``request.remote_addr``.

    Args:
        request: Object exposing ``headers.get``, ``headers.getlist`` and
            ``remote_addr``.

    Returns:
        The selected address. As in the fallback branch of the original code,
        this is whatever ``request.remote_addr`` holds when no header is set.
    """
    # NOTE(review): both headers are taken at face value; whether they can be
    # trusted depends on the deployment's proxy setup - confirm.
    cf_connecting_ip = request.headers.get("CF-Connecting-IP")
    if cf_connecting_ip:
        return cast(str, cf_connecting_ip)

    forwarded_for = request.headers.getlist("X-Forwarded-For")
    if forwarded_for:
        # An X-Forwarded-For value may be a comma-separated chain
        # ("client, proxy1, proxy2"); return only the first entry instead of
        # the whole chain.
        return cast(str, forwarded_for[0]).split(",")[0].strip()

    return cast(str, request.remote_addr)
-
-
def generate_text_hash(text: str) -> str:
    """
    Return the SHA-256 hex digest of ``text`` with the literal suffix ``"None"``.

    The suffix is part of the hashed payload, so the result differs from a
    plain SHA-256 of ``text``.
    """
    payload = (str(text) + "None").encode()
    digest = sha256(payload)
    return digest.hexdigest()
-
-
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    """
    Wrap ``response`` in a Flask ``Response``.

    A ``dict`` is encoded with ``jsonable_encoder`` and sent as a single
    ``application/json`` body. Anything else is iterated and streamed as
    ``text/event-stream``. Both paths use status 200.
    """
    # NOTE: the check is for ``dict`` specifically, so a non-dict Mapping takes
    # the streaming path.
    if not isinstance(response, dict):

        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")

    body = json.dumps(jsonable_encoder(response))
    return Response(response=body, status=200, mimetype="application/json")
-
-
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    """
    Return ``response`` as one or more length-prefixed binary frames.

    This framing exists for compatibility with the plugin daemon
    (https://github.com/langgenius/dify-plugin-daemon/pull/341); it replaces a
    line-based response format, which led to a memory issue.

    Each frame is laid out as follows, little endian:

        | Field         | Size     | Description                         |
        |---------------|----------|-------------------------------------|
        | Magic Number  | 1 byte   | ``magic_number``, identifies type   |
        | Reserved      | 1 byte   | always 0                            |
        | Header Length | 2 bytes  | always 0xA here                     |
        | Data Length   | 4 bytes  | length of the data in bytes         |
        | Reserved      | 6 bytes  | zero-filled                         |
        | Data          | variable | the payload itself                  |

    A ``dict`` (JSON-encoded via ``jsonable_encoder``) or a pydantic
    ``BaseModel`` (via ``model_dump_json``) is sent as a single frame with
    mimetype ``application/json``. Any other value is iterated and every chunk
    is sent as its own frame with mimetype ``text/event-stream``; ``str``
    chunks are UTF-8 encoded first and other chunks are used as-is.
    """
    header_length = 0xA
    reserved_tail = b"\x00" * 6

    def frame(payload: bytes) -> bytes:
        # "<BBHI" = magic (1) | reserved (1) | header length (2) | data length (4)
        prefix = struct.pack("<BBHI", magic_number, 0, header_length, len(payload))
        return prefix + reserved_tail + payload

    if isinstance(response, dict):
        encoded = json.dumps(jsonable_encoder(response)).encode("utf-8")
        return Response(response=frame(encoded), status=200, mimetype="application/json")

    if isinstance(response, BaseModel):
        encoded = response.model_dump_json().encode("utf-8")
        return Response(response=frame(encoded), status=200, mimetype="application/json")

    def generate() -> Generator:
        for chunk in response:
            payload = chunk.encode("utf-8") if isinstance(chunk, str) else chunk
            yield frame(payload)

    return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
-
-
class TokenManager:
    """Issue, look up and revoke expiring tokens kept in Redis."""

    @classmethod
    def generate_token(
        cls,
        token_type: str,
        account: Optional["Account"] = None,
        email: Optional[str] = None,
        additional_data: Optional[dict] = None,
    ) -> str:
        """
        Create a new token of ``token_type`` and store its payload in Redis.

        When an account is given, any token currently recorded for that account
        and type is revoked first, and the new token is recorded in its place.

        Args:
            token_type: Token category; also selects the config entry
                ``<TOKEN_TYPE>_TOKEN_EXPIRY_MINUTES`` used for the expiry.
            account: Account the token belongs to, if any.
            email: Email used when no account is given.
            additional_data: Extra entries merged into the stored payload.

        Returns:
            The newly generated token (a UUID4 string).

        Raises:
            ValueError: If neither ``account`` nor ``email`` is provided, or if
                no expiry is configured for ``token_type``.
        """
        if account is None and email is None:
            raise ValueError("Account or email must be provided")

        if account:
            account_id, account_email = account.id, account.email
        else:
            account_id, account_email = None, email

        if account_id:
            previous_token = cls._get_current_token_for_account(account_id, token_type)
            if previous_token:
                # The Redis client may hand back bytes; token keys are built from str.
                if isinstance(previous_token, bytes):
                    previous_token = previous_token.decode("utf-8")
                cls.revoke_token(previous_token, token_type)

        token = str(uuid.uuid4())
        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
        if additional_data:
            token_data.update(additional_data)

        config_name = f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES"
        expiry_minutes = dify_config.model_dump().get(config_name)
        if expiry_minutes is None:
            raise ValueError(f"Expiry minutes for {token_type} token is not set")

        redis_client.setex(
            cls._get_token_key(token, token_type),
            int(expiry_minutes * 60),
            json.dumps(token_data),
        )

        if account_id:
            cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)

        return token

    @classmethod
    def _get_token_key(cls, token: str, token_type: str) -> str:
        """Return the Redis key under which a token's payload is stored."""
        return ":".join((token_type, "token", token))

    @classmethod
    def revoke_token(cls, token: str, token_type: str):
        """Delete the stored payload of ``token``."""
        redis_client.delete(cls._get_token_key(token, token_type))

    @classmethod
    def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
        """
        Return the stored payload for ``token``.

        Returns ``None`` and logs a warning when the key is not in Redis.
        """
        key = cls._get_token_key(token, token_type)
        raw_payload = redis_client.get(key)
        if raw_payload is not None:
            token_data: Optional[dict[str, Any]] = json.loads(raw_payload)
            return token_data

        logger.warning("%s token %s not found with key %s", token_type, token, key)
        return None

    @classmethod
    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
        """Return the token currently recorded for the account and type, if any."""
        current_token: Optional[str] = redis_client.get(cls._get_account_token_key(account_id, token_type))
        return current_token

    @classmethod
    def _set_current_token_for_account(
        cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
    ):
        """Record ``token`` as the current token for the account, with the same expiry."""
        redis_client.setex(
            cls._get_account_token_key(account_id, token_type),
            int(expiry_minutes * 60),
            token,
        )

    @classmethod
    def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
        """Return the Redis key that tracks an account's current token of a type."""
        return f"{token_type}:account:{account_id}"
-
-
class RateLimiter:
    """
    Sliding-window attempt limiter backed by a Redis sorted set.

    Each recorded attempt is one sorted-set member whose score is the Unix
    time (in whole seconds) at which it was recorded.
    """

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        """
        Args:
            prefix: Key prefix that separates this limiter's keys from others.
            max_attempts: Number of attempts within the window at which the
                limiter starts reporting "limited".
            time_window: Window length in seconds.
        """
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        """Return the Redis key that holds the attempts recorded for ``email``."""
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        """
        Report whether ``email`` has reached ``max_attempts`` within the window.

        Attempts whose score is at or before the start of the window are
        removed before counting.
        """
        key = self._get_key(email)
        window_start_time = int(time.time()) - self.time_window

        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        return bool(attempts) and int(attempts) >= self.max_attempts

    def increment_rate_limit(self, email: str):
        """Record one attempt for ``email`` and refresh the key's expiry."""
        key = self._get_key(email)
        current_time = int(time.time())

        # The member must be unique per attempt. Using the integer second as
        # the member would make several attempts within the same second
        # overwrite one another, and they would be counted as a single attempt.
        member = f"{current_time}:{uuid.uuid4().hex}"
        redis_client.zadd(key, {member: current_time})
        redis_client.expire(key, self.time_window * 2)
|