diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 10d86d17627048..cb3a6a5ea9b24f 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -100,164 +101,173 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: :param tenant_id: :return: """ - # Get all provider records of the workspace - provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) + with session_factory.create_session() as session: + # Get all provider records of the workspace + provider_name_to_provider_records_dict = self._get_all_providers(session, tenant_id) - # Initialize trial provider records if not exist - provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, provider_name_to_provider_records_dict - ) + # Initialize trial provider records if not exist + provider_name_to_provider_records_dict = self._init_trial_provider_records( + tenant_id, provider_name_to_provider_records_dict + ) - # append providers with langgenius/openai/openai - provider_name_list = list(provider_name_to_provider_records_dict.keys()) - for provider_name in provider_name_list: - provider_id = ModelProviderID(provider_name) - if str(provider_id) not in provider_name_list: - provider_name_to_provider_records_dict[str(provider_id)] = provider_name_to_provider_records_dict[ - provider_name - ] - - # Get all provider model records of the workspace - provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id) - for provider_name in list(provider_name_to_provider_model_records_dict.keys()): - provider_id = ModelProviderID(provider_name) - if str(provider_id) not in provider_name_to_provider_model_records_dict: - provider_name_to_provider_model_records_dict[str(provider_id)] = ( - provider_name_to_provider_model_records_dict[provider_name] - ) + # append providers with langgenius/openai/openai + provider_name_list = list(provider_name_to_provider_records_dict.keys()) + for provider_name in provider_name_list: + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_list: + provider_name_to_provider_records_dict[str(provider_id)] = provider_name_to_provider_records_dict[ + provider_name + ] + + # Get all provider model records of the workspace + provider_name_to_provider_model_records_dict = self._get_all_provider_models(session, tenant_id) + for provider_name in list(provider_name_to_provider_model_records_dict.keys()): + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_to_provider_model_records_dict: + provider_name_to_provider_model_records_dict[str(provider_id)] = ( + provider_name_to_provider_model_records_dict[provider_name] + ) - # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() - - # Get All preferred provider types of the workspace - provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) - # Ensure that both the original provider name and its ModelProviderID string representation - # are present in the dictionary to handle cases where either form might be used - for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()): - provider_id = ModelProviderID(provider_name) - if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict: - # Add the ModelProviderID string representation if it's not already present - provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( - provider_name_to_preferred_model_provider_records_dict[provider_name] - ) + # Get all provider entities + model_provider_factory = ModelProviderFactory(tenant_id) + provider_entities = model_provider_factory.get_providers() - # Get All provider model settings - provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) + # Get All preferred provider types of the workspace + provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers( + session, tenant_id + ) + # Ensure that both the original provider name and its ModelProviderID string representation + # are present in the dictionary to handle cases where either form might be used + for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()): + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict: + # Add the ModelProviderID string representation if it's not already present + provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( + provider_name_to_preferred_model_provider_records_dict[provider_name] + ) - # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( - tenant_id - ) + # Get All provider model settings + provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(session, tenant_id) - # Get All provider model credentials - provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id) + # Get All load balancing configs + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + session, tenant_id + ) - provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Get All provider model credentials + provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials( + session, tenant_id + ) - # Construct ProviderConfiguration objects for each provider - for provider_entity in provider_entities: - # handle include, exclude - if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, - data=provider_entity, - name_func=lambda x: x.provider, - ): - continue + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + + # Construct ProviderConfiguration objects for each provider + for provider_entity in provider_entities: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue - provider_name = provider_entity.provider - provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) - provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) - provider_id_entity = ModelProviderID(provider_name) - if provider_id_entity.is_langgenius(): - provider_model_records.extend( - provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) - ) - provider_model_credentials = provider_name_to_provider_model_credentials_dict.get( - provider_entity.provider, [] - ) - provider_id_entity = ModelProviderID(provider_name) - if provider_id_entity.is_langgenius(): - provider_model_credentials.extend( - provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, []) + provider_name = provider_entity.provider + provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) + provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) + provider_id_entity = ModelProviderID(provider_name) + if provider_id_entity.is_langgenius(): + provider_model_records.extend( + provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, []) + ) + provider_model_credentials = provider_name_to_provider_model_credentials_dict.get( + provider_entity.provider, [] ) + provider_id_entity = ModelProviderID(provider_name) + if provider_id_entity.is_langgenius(): + provider_model_credentials.extend( + provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, []) + ) - # Convert to custom configuration - custom_configuration = self._to_custom_configuration( - tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials - ) + # Convert to custom configuration + custom_configuration = self._to_custom_configuration( + tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials + ) - # Convert to system configuration - system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) + # Convert to system configuration + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) - # Get preferred provider type - preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) + # Get preferred provider type + preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get( + provider_name + ) - if preferred_provider_type_record: - preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) - elif custom_configuration.provider or custom_configuration.models: - preferred_provider_type = ProviderType.CUSTOM - elif system_configuration.enabled: - preferred_provider_type = ProviderType.SYSTEM - else: - preferred_provider_type = ProviderType.CUSTOM + if preferred_provider_type_record: + preferred_provider_type = ProviderType.value_of( + preferred_provider_type_record.preferred_provider_type + ) + elif custom_configuration.provider or custom_configuration.models: + preferred_provider_type = ProviderType.CUSTOM + elif system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM + else: + preferred_provider_type = ProviderType.CUSTOM - using_provider_type = preferred_provider_type - has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) + using_provider_type = preferred_provider_type + has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) - if preferred_provider_type == ProviderType.SYSTEM: - if not system_configuration.enabled or not has_valid_quota: - using_provider_type = ProviderType.CUSTOM + if preferred_provider_type == ProviderType.SYSTEM: + if not system_configuration.enabled or not has_valid_quota: + using_provider_type = ProviderType.CUSTOM - else: - if not custom_configuration.provider and not custom_configuration.models: - if system_configuration.enabled and has_valid_quota: - using_provider_type = ProviderType.SYSTEM + else: + if not custom_configuration.provider and not custom_configuration.models: + if system_configuration.enabled and has_valid_quota: + using_provider_type = ProviderType.SYSTEM - # Get provider load balancing configs - provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) + # Get provider load balancing configs + provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) - # Get provider load balancing configs - provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( - provider_name - ) + # Get provider load balancing configs + provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) - provider_id_entity = ModelProviderID(provider_name) + provider_id_entity = ModelProviderID(provider_name) - if provider_id_entity.is_langgenius(): - if provider_model_settings is not None: - provider_model_settings.extend( - provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, []) - ) - if provider_load_balancing_configs is not None: - provider_load_balancing_configs.extend( - provider_name_to_provider_load_balancing_model_configs_dict.get( - provider_id_entity.provider_name, [] + if provider_id_entity.is_langgenius(): + if provider_model_settings is not None: + provider_model_settings.extend( + provider_name_to_provider_model_settings_dict.get(provider_id_entity.provider_name, []) + ) + if provider_load_balancing_configs is not None: + provider_load_balancing_configs.extend( + provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_id_entity.provider_name, [] + ) ) - ) - # Convert to model settings - model_settings = self._to_model_settings( - provider_entity=provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs, - ) + # Convert to model settings + model_settings = self._to_model_settings( + provider_entity=provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=provider_load_balancing_configs, + ) - provider_configuration = ProviderConfiguration( - tenant_id=tenant_id, - provider=provider_entity, - preferred_provider_type=preferred_provider_type, - using_provider_type=using_provider_type, - system_configuration=system_configuration, - custom_configuration=custom_configuration, - model_settings=model_settings, - ) + provider_configuration = ProviderConfiguration( + tenant_id=tenant_id, + provider=provider_entity, + preferred_provider_type=preferred_provider_type, + using_provider_type=using_provider_type, + system_configuration=system_configuration, + custom_configuration=custom_configuration, + model_settings=model_settings, + ) - provider_configurations[str(provider_id_entity)] = provider_configuration + provider_configurations[str(provider_id_entity)] = provider_configuration - # Return the encapsulated object - return provider_configurations + # Return the encapsulated object + return provider_configurations def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle: """ @@ -402,18 +412,19 @@ def update_default_model_record( return default_model @staticmethod - def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: + def _get_all_providers(session: Session, tenant_id: str) -> dict[str, list[Provider]]: provider_name_to_provider_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) - providers = session.scalars(stmt) - for provider in providers: - # Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) + stmt = select(Provider).where(Provider.tenant_id == tenant_id) + providers = session.scalars(stmt) + for provider in providers: + if not provider.is_valid: + continue + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod - def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: + def _get_all_provider_models(session: Session, tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. @@ -421,15 +432,16 @@ def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: :return: """ provider_name_to_provider_model_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - provider_models = session.scalars(stmt) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + if not provider_model.is_valid: + continue + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod - def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: + def _get_all_preferred_model_providers(session: Session, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. @@ -437,17 +449,16 @@ def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPrefer :return: """ provider_name_to_preferred_provider_type_records_dict = {} - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) - preferred_provider_types = session.scalars(stmt) - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod - def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + def _get_all_provider_model_settings(session: Session, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. @@ -455,17 +466,19 @@ def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderM :return: """ provider_name_to_provider_model_settings_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) - provider_model_settings = session.scalars(stmt) - for provider_model_setting in provider_model_settings: - provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( - provider_model_setting - ) + stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + return provider_name_to_provider_model_settings_dict @staticmethod - def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]: + def _get_all_provider_model_credentials( + session: Session, tenant_id: str + ) -> dict[str, list[ProviderModelCredential]]: """ Get All provider model credentials of the workspace. @@ -473,17 +486,18 @@ def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[Provid :return: """ provider_name_to_provider_model_credentials_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id) - provider_model_credentials = session.scalars(stmt) - for provider_model_credential in provider_model_credentials: - provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append( - provider_model_credential - ) + stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id) + provider_model_credentials = session.scalars(stmt) + for provider_model_credential in provider_model_credentials: + provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append( + provider_model_credential + ) return provider_name_to_provider_model_credentials_dict @staticmethod - def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + def _get_all_provider_load_balancing_configs( + session: Session, tenant_id: str + ) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. @@ -503,13 +517,12 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L return {} provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) - provider_load_balancing_configs = session.scalars(stmt) - for provider_load_balancing_config in provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict