django saas多租户 分库分表 一个租户一个数据库
2024-03-20 本文已影响0人
张晓畅
- 企业(租户)应用,一个企业算一个租户
|----其他应用
|----organize
|----management 自定义迁移命令
|----commands
|-----__init__.py
|----clearmigrations.py 清除所有迁移
|----organize_migrate.py 企业数据库迁移
|----patch
|----__init__.py
|----connection.py 添加企业数据库连接setting.DATABASES
|----contenttype.py 重构contenttype
|----permission.py 重构权限
|----request.py 请求接口根据用户企业id设置数据库连接
|----app.py 应用初始化
|----signal.py 信号,创建企业的时候触发创建数据库
|----其他文件
多数据库流程:
-、以企业id小写创建数据库
-、创建企业的时候通过信号触发create_database创建企业数据库
-、应用迁移到哪个数据库由DATABASE_APPS_MAPPING决定
-、__connection_handler__getitem__会把DATABASES里不存在的数据库添加到连接配置里
-、数据库读写由DBRouter路由
-、请求接口会根据request.user的organize id设置对应的数据库
-、设置数据set_current_db
-、获取数据库get_current_db,默认是default数据库
-、企业获取模型所有数据get_model_all_data
-、创建数据会根据新数据带的organize id指定数据库
-、获取数据api_view.get_queryset(自定义视图)会查询对应的数据库
-、数据库迁移命令:
迁移所有企业(不包括default):python manage.py organize_migrate --all
迁移xx企业(xx=>企业id小写):python manage.py organize_migrate --database=xxx
- setting.py 配置
INSTALLED_APPS = [
    # Shared (platform) apps.
    'organize',
    # ... other shared apps ...
    # Tenant apps.
    'organizeapp',
    # ... other tenant apps ...
]

# Per-app database pinning, keyed by app label.
# Apps dedicated to a tenant may be pinned to a specific database, apps shared
# by all tenants need no entry, and every app that does not belong to tenants
# must be mapped to the ``default`` database.
DATABASE_APPS_MAPPING = {
    'organize': 'default',
}

# Databases. Only the platform database is declared here.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.mysql',
        'NAME': 'databasename',
        'USER': 'user',
        'PASSWORD': 'password',
        'HOST': '127.0.0.1',
        'PORT': '3306',
    },
}

# Database router.
DATABASE_ROUTERS = ['utils.api_db.DBRouter']
- api_db.py 数据库处理
import copy
from django.conf import settings
from django.db.models import Model, Q
from threading import local, Thread
from django.db.utils import load_backend
from django.apps import apps as django_apps
# Thread-local storage; get_current_db()/set_current_db() keep the selected
# database alias on it under the attribute ``db_name``.
_thread_local = local()
def create_database(org_id: str, org_name: str):
    """Create the database of one tenant and start migrating it.

    The database is named after ``org_id``. Once it exists, ``migrate`` is
    started in a separate thread, so this function returns before the
    migration has finished.

    :param org_id: tenant id, used verbatim as the database name
    :param org_name: tenant name, forwarded to ``migrate``
    :return: ``True`` once the database exists and the migration thread started
    :raises ValueError: if ``org_id`` is not a plain identifier
    """
    import re

    # An identifier cannot be passed as a bound query parameter, so it is
    # validated against a strict whitelist before being placed in the statement.
    if not isinstance(org_id, str) or not re.fullmatch(r'[A-Za-z0-9_$]+', org_id):
        raise ValueError(f'Invalid database name: {org_id!r}')

    # Connect with the default settings but without selecting a database.
    db_config = get_default_db_config(None)
    backend = load_backend(db_config["ENGINE"])
    connection = backend.DatabaseWrapper(db_config, org_id)
    quoted_name = connection.ops.quote_name(org_id)
    sql = f"CREATE DATABASE IF NOT EXISTS {quoted_name} character set utf8;"
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
    finally:
        # This wrapper is not managed by django.db.connections; close it here.
        connection.close()

    # Run the migration in the background.
    thread = Thread(target=migrate, args=[org_id, org_name])
    thread.start()
    return True
def migrate(org_id: str, org_name: str):
    """Run ``manage.py migrate --database=<org_id>`` in-process.

    :param org_id: alias of the tenant database to migrate
    :param org_name: tenant name (not used by the migration itself)
    :raises ImportError: if Django cannot be imported
    """
    try:
        from django.core.management import execute_from_command_line as run_command
    except ImportError as exc:
        message = (
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        )
        raise ImportError(message) from exc

    argv = ['manage.py', 'migrate', f'--database={org_id}']
    run_command(argv)
def get_model_all_data(queryset=None, filter_conditions=None, view=None, organize=''):
    """Collect the rows of a model from one tenant or from every tenant.

    Two ways of calling:
      1. from a view (API): pass ``view``;
      2. outside a view: pass ``queryset`` and optionally ``filter_conditions``.

    When ``organize`` (a tenant id) is given, only that tenant is queried.
    Otherwise every ``organize.Organize`` id is queried in turn and the merged
    rows are sorted by creation time and id, both descending.
    """
    query_kwargs = dict(queryset=queryset, filter_conditions=filter_conditions, view=view)

    if organize != '':
        return get_db_data(organize=organize, **query_kwargs)

    Organize = get_model('organize.Organize')
    merged_rows = []
    for org_pk in Organize.objects.values_list('id', flat=True):
        merged_rows.extend(get_db_data(organize=org_pk, **query_kwargs))

    # Newest first; rows with the same creation time are ordered by id, descending.
    merged_rows.sort(key=lambda row: (row['create_time'], row['id']), reverse=True)
    return merged_rows
def get_db_data(queryset=None, filter_conditions=None, view=None, organize=None):
    """Fetch the rows of one tenant database, newest first.

    Two ways of calling:
      1. from a view (API): pass ``view``; its queryset is filtered by the
         view and serialized with the view's serializer;
      2. outside a view: pass ``queryset`` and optionally ``filter_conditions``
         (an iterable of dicts, each turned into one ``Q(**item)``).

    The current thread is routed to the tenant database only while the query
    runs. The previously selected database is restored afterwards, so later
    queries on the same thread do not silently hit this tenant.

    :param organize: tenant id; its lower-cased form is the database alias
    :return: list of rows ordered by ``-create_time``, ``-id``
    """
    previous_db = get_current_db()
    set_current_db(organize.lower())
    try:
        if view is not None:
            # View data goes through the view's own filtering and serialization.
            org_queryset = view.filter_queryset(view.get_queryset()) \
                .order_by('-create_time', '-id')
            serializer = view.get_serializer(instance=org_queryset, many=True)
            return list(serializer.data)

        # Non-view data is filtered by the given conditions.
        conditions = [Q(**item) for item in filter_conditions] if filter_conditions else []
        rows = queryset.filter(*conditions).order_by('-create_time', '-id').values()
        # Evaluate the query while the tenant database is still selected.
        return list(rows)
    finally:
        set_current_db(previous_db)
def get_current_db():
    """Return the database alias stored for this thread, or ``'default'`` if none was set."""
    try:
        return _thread_local.db_name
    except AttributeError:
        return 'default'
def set_current_db(db_name):
    """Store the lower-cased ``db_name`` as this thread's database alias."""
    _thread_local.db_name = db_name.lower()
def get_default_db_config(db=None):
    """Return a copy of the ``default`` database settings pointed at ``db``.

    :param db: database name; stored lower-cased, or ``None`` for no database
    :return: an independent dict whose ``NAME`` is the requested database
    :raises KeyError: if ``settings.DATABASES`` has no ``default`` entry
    """
    default_config = settings.DATABASES.get('default')
    if default_config is None:
        raise KeyError('获取不到数据库配置')

    tenant_config = copy.deepcopy(default_config)
    tenant_config['NAME'] = db.lower() if db is not None else None
    return tenant_config
class DBRouter:
    """Database router for reads, writes and migrations."""

    @staticmethod
    def _alias_for(model):
        """Return the pinned alias of the model's app, else the thread's current alias."""
        label = model._meta.app_label
        pinned = settings.DATABASE_APPS_MAPPING
        return pinned[label] if label in pinned else get_current_db()

    def db_for_read(self, model: Model, **hints) -> str:
        return self._alias_for(model)

    def db_for_write(self, model: Model, **hints):
        return self._alias_for(model)

    def allow_migrate(self, db: str, app_label: str, **hints) -> bool:
        """Apps mapped to ``default`` migrate only there; all others only to non-default databases."""
        # contenttypes is migrated into every database.
        if app_label == 'contenttypes':
            return True
        app_on_default = settings.DATABASE_APPS_MAPPING.get(app_label, None) == 'default'
        target_is_default = db == 'default'
        return app_on_default == target_is_default
def get_custom_id(model, pre_str, num):
    """Generate the next custom id for ``model``.

    The last sequence number is read from redis. When redis has none, it is
    taken from the newest row in the ``default`` database, falling back to the
    data of all tenant databases. Ids must be unique across all data.

    NOTE(review): the redis read followed by a separate write is not atomic;
    concurrent callers could obtain the same id.

    :param model: model class the id is generated for
    :param pre_str: prefix of the id
    :param num: minimum number of digits after the prefix (zero padded)
    :return: str, the new custom id
    """
    import re

    from django.db import DatabaseError

    redis_con = get_redis_connection("default")
    last_id = redis_con.hget('E:ADMIN:LASTID', model._meta.app_label)
    if last_id is not None:
        last_seq = int(last_id)
    else:
        try:
            last_row = model.objects.using('default').order_by('create_time', 'id').last()
        except DatabaseError:
            # e.g. the model's table does not exist in the default database
            last_row = None

        if last_row is not None:
            last_id = last_row.id
        else:
            # Nothing in the default database: look at the tenant databases.
            rows = get_model_all_data(model.objects.all())
            last_id = rows[0]['id'] if rows else None

        if last_id is None or not last_id.startswith(pre_str):
            # First id carrying this prefix.
            last_seq = 0
        else:
            # Continue from the existing id: drop the prefix and leading zeros.
            digits = re.sub(r'\D+0*', '', last_id)
            last_seq = int(digits) if digits else 0

    new_id = str(last_seq + 1)
    redis_con.hset('E:ADMIN:LASTID', model._meta.app_label, new_id)
    # Left-pad with zeros up to ``num`` digits; longer numbers are kept as is.
    return pre_str + new_id.zfill(num)
- clearmigrations.py 清除所有迁移
import os
import shutil
from django.core.management.base import BaseCommand
from django.apps import apps
class Command(BaseCommand):
    """Reset the ``migrations`` package of every app that belongs to the project."""

    help = "Clear all migration files in project."

    @staticmethod
    def _is_project_app(app_dir, project_root):
        """Tell whether ``app_dir`` is project code rather than an installed package."""
        if 'site-packages' in app_dir.parts or 'dist-packages' in app_dir.parts:
            return False
        if project_root is None:
            return True
        return app_dir == project_root or project_root in app_dir.parents

    def get_apps(self):
        """Yield ``(app_config, migrations_path)`` for project apps that have migrations.

        Installed packages (Django itself, third-party apps) are skipped:
        removing their migrations would damage the installation.
        """
        from pathlib import Path

        from django.conf import settings

        base_dir = getattr(settings, 'BASE_DIR', None)
        project_root = Path(base_dir).resolve() if base_dir is not None else None
        for app in apps.get_app_configs():
            if not self._is_project_app(Path(app.path).resolve(), project_root):
                continue
            path = os.path.join(app.path, "migrations")
            if os.path.isdir(path):
                yield app, path

    def handle(self, *args, **options):
        for app, path in self.get_apps():
            # Remove the whole folder recursively ...
            shutil.rmtree(path)
            # ... then recreate it as an empty package.
            os.makedirs(path)
            with open(os.path.join(path, "__init__.py"), "w") as file:
                file.write("")
            self.stdout.write(self.style.SUCCESS(f"Clear {path}"))
        self.stdout.write(self.style.SUCCESS("Successfully cleared!"))
- organize_migrate.py 企业数据库迁移
import sys
from importlib import import_module
from django.apps import apps
from django.core.management.base import CommandError, no_translations
from django.core.management.sql import (
emit_post_migrate_signal, emit_pre_migrate_signal,
)
from django.db import connections
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.loader import AmbiguityError
from django.db.migrations.state import ModelState, ProjectState
from django.utils.module_loading import module_has_submodule
from django.core.management.commands.migrate import Command as MigrateCommand
class Command(MigrateCommand):
    """
    Migrate tenant (organization) databases.

    api_db.DBRouter.allow_migrate decides which models are migrated into which database.
    Migrate tenant xxx: python manage.py organize_migrate --database xxx
    Migrate all tenants: python manage.py organize_migrate --all
    """
    help = "organize migrate database."
    requires_system_checks = []

    def add_arguments(self, parser):
        """Register the ``--all`` flag, then the arguments of the parent migrate command."""
        parser.add_argument(
            '--all', action='store_true',
            help='migrate all database.',
        )
        super().add_arguments(parser)

    @no_translations
    def handle(self, *args, **options):
        """Migrate every tenant database when ``--all`` is set, else only ``--database``."""
        if options['all']:
            # Migrate all tenants.
            from organize import get_all_organize_db
            dbs = list(get_all_organize_db().keys())
            for key in dbs:
                # Looking the alias up creates its connection before migrating.
                connections[key]
                self.execute_migrate(key, args, options)
        else:
            # Migrate a single tenant.
            cur_database = options['database']
            if not cur_database in connections:
                # Looking the alias up creates its connection; with the patch in
                # patch/connection.py an unknown alias is added to the settings first.
                connections[cur_database]
            self.execute_migrate(cur_database, args, options)

    def execute_migrate(self, cur_database, args, options):
        """Run the migration for the single database alias ``cur_database``.

        NOTE(review): the body appears to be adapted from the ``handle`` method of
        Django's built-in migrate command, operating on ``cur_database``;
        confirm it still matches the installed Django version.

        :param cur_database: alias of the database to migrate
        :param args: positional arguments received by ``handle`` (not used here)
        :param options: parsed command options
        """
        if not options['skip_checks']:
            self.check(databases=[cur_database])

        self.verbosity = options['verbosity']
        self.interactive = options['interactive']

        # Import the 'management' module within each installed app, to register
        # dispatcher events.
        for app_config in apps.get_app_configs():
            if module_has_submodule(app_config.module, "management"):
                import_module('.management', app_config.name)

        # Get the database we're operating from
        self.stdout.write(f'\nMigrate Database: {cur_database}')
        connection = connections[cur_database]

        # Hook for backends needing any database preparation
        connection.prepare_database()
        # Work out which apps have migrations and which do not
        executor = MigrationExecutor(connection, self.migration_progress_callback)

        # Raise an error if any migrations are applied before their dependencies.
        executor.loader.check_consistent_history(connection)

        # Before anything else, see if there's conflicting apps and drop out
        # hard if there are any
        conflicts = executor.loader.detect_conflicts()
        if conflicts:
            name_str = "; ".join(
                "%s in %s" % (", ".join(names), app)
                for app, names in conflicts.items()
            )
            raise CommandError(
                "Conflicting migrations detected; multiple leaf nodes in the "
                "migration graph: (%s).\nTo fix them run "
                "'python manage.py makemigrations --merge'" % name_str
            )

        # If they supplied command line arguments, work out what they mean.
        run_syncdb = options['run_syncdb']
        target_app_labels_only = True
        if options['app_label']:
            # Validate app_label.
            app_label = options['app_label']
            try:
                apps.get_app_config(app_label)
            except LookupError as err:
                raise CommandError(str(err))
            if run_syncdb:
                if app_label in executor.loader.migrated_apps:
                    raise CommandError("Can't use run_syncdb with app '%s' as it has migrations." % app_label)
            elif app_label not in executor.loader.migrated_apps:
                raise CommandError("App '%s' does not have migrations." % app_label)

        if options['app_label'] and options['migration_name']:
            migration_name = options['migration_name']
            if migration_name == "zero":
                targets = [(app_label, None)]
            else:
                try:
                    migration = executor.loader.get_migration_by_prefix(app_label, migration_name)
                except AmbiguityError:
                    raise CommandError(
                        "More than one migration matches '%s' in app '%s'. "
                        "Please be more specific." %
                        (migration_name, app_label)
                    )
                except KeyError:
                    raise CommandError("Cannot find a migration matching '%s' from app '%s'." % (
                        migration_name, app_label))
                targets = [(app_label, migration.name)]
            target_app_labels_only = False
        elif options['app_label']:
            targets = [key for key in executor.loader.graph.leaf_nodes() if key[0] == app_label]
        else:
            targets = executor.loader.graph.leaf_nodes()

        plan = executor.migration_plan(targets)
        # Truthy when unapplied migrations exist and --check was requested.
        exit_dry = plan and options['check_unapplied']

        if options['plan']:
            self.stdout.write('Planned operations:', self.style.MIGRATE_LABEL)
            if not plan:
                self.stdout.write(' No planned migration operations.')
            for migration, backwards in plan:
                self.stdout.write(str(migration), self.style.MIGRATE_HEADING)
                for operation in migration.operations:
                    message, is_error = self.describe_operation(operation, backwards)
                    style = self.style.WARNING if is_error else None
                    self.stdout.write(' ' + message, style)
            if exit_dry:
                sys.exit(1)
            return
        if exit_dry:
            sys.exit(1)

        # At this point, ignore run_syncdb if there aren't any apps to sync.
        run_syncdb = options['run_syncdb'] and executor.loader.unmigrated_apps
        # Print some useful info
        if self.verbosity >= 1:
            self.stdout.write(self.style.MIGRATE_HEADING("Operations to perform:"))
            if run_syncdb:
                if options['app_label']:
                    self.stdout.write(
                        self.style.MIGRATE_LABEL(" Synchronize unmigrated app: %s" % app_label)
                    )
                else:
                    self.stdout.write(
                        self.style.MIGRATE_LABEL(" Synchronize unmigrated apps: ") +
                        (", ".join(sorted(executor.loader.unmigrated_apps)))
                    )
            if target_app_labels_only:
                self.stdout.write(
                    self.style.MIGRATE_LABEL(" Apply all migrations: ") +
                    (", ".join(sorted({a for a, n in targets})) or "(none)")
                )
            else:
                if targets[0][1] is None:
                    self.stdout.write(
                        self.style.MIGRATE_LABEL(' Unapply all migrations: ') +
                        str(targets[0][0])
                    )
                else:
                    self.stdout.write(self.style.MIGRATE_LABEL(
                        " Target specific migration: ") + "%s, from %s"
                        % (targets[0][1], targets[0][0])
                    )

        pre_migrate_state = executor._create_project_state(with_applied_migrations=True)
        pre_migrate_apps = pre_migrate_state.apps
        emit_pre_migrate_signal(
            self.verbosity, self.interactive, connection.alias, apps=pre_migrate_apps, plan=plan,
        )

        # Run the syncdb phase.
        if run_syncdb:
            if self.verbosity >= 1:
                self.stdout.write(self.style.MIGRATE_HEADING("Synchronizing apps without migrations:"))
            if options['app_label']:
                self.sync_apps(connection, [app_label])
            else:
                self.sync_apps(connection, executor.loader.unmigrated_apps)

        # Migrate!
        if self.verbosity >= 1:
            self.stdout.write(self.style.MIGRATE_HEADING("Running migrations:"))
        if not plan:
            if self.verbosity >= 1:
                self.stdout.write(" No migrations to apply.")
                # If there's changes that aren't in migrations yet, tell them how to fix it.
                autodetector = MigrationAutodetector(
                    executor.loader.project_state(),
                    ProjectState.from_apps(apps),
                )
                changes = autodetector.changes(graph=executor.loader.graph)
                if changes:
                    self.stdout.write(self.style.NOTICE(
                        " Your models in app(s): %s have changes that are not "
                        "yet reflected in a migration, and so won't be "
                        "applied." % ", ".join(repr(app) for app in sorted(changes))
                    ))
                    self.stdout.write(self.style.NOTICE(
                        " Run 'manage.py makemigrations' to make new "
                        "migrations, and then re-run 'manage.py migrate' to "
                        "apply them."
                    ))
            fake = False
            fake_initial = False
        else:
            fake = options['fake']
            fake_initial = options['fake_initial']
        post_migrate_state = executor.migrate(
            targets, plan=plan, state=pre_migrate_state.clone(), fake=fake,
            fake_initial=fake_initial,
        )
        # post_migrate signals have access to all models. Ensure that all models
        # are reloaded in case any are delayed.
        post_migrate_state.clear_delayed_apps_cache()
        post_migrate_apps = post_migrate_state.apps

        # Re-render models of real apps to include relationships now that
        # we've got a final state. This wouldn't be necessary if real apps
        # models were rendered with relationships in the first place.
        with post_migrate_apps.bulk_update():
            model_keys = []
            for model_state in post_migrate_apps.real_models:
                model_key = model_state.app_label, model_state.name_lower
                model_keys.append(model_key)
                post_migrate_apps.unregister_model(*model_key)
        post_migrate_apps.render_multiple([
            ModelState.from_model(apps.get_model(*model)) for model in model_keys
        ])

        # Send the post_migrate signal, so individual apps can do whatever they need
        # to do at this point.
        emit_post_migrate_signal(
            self.verbosity, self.interactive, connection.alias, apps=post_migrate_apps, plan=plan,
        )
- connection.py 添加企业数据库连接setting.DATABASES
import logging
from django.db.utils import ConnectionHandler
from utils.api_db import get_default_db_config
logger = logging.getLogger('django')


def __connection_handler__getitem__(self, alias: str):
    """Return the connection for ``alias``, registering unknown aliases on the fly.

    Replacement for ``ConnectionHandler.__getitem__``. When ``alias`` is not in
    the handler's settings, a copy of the ``default`` configuration whose
    ``NAME`` is the alias is added, so tenant databases need no entry in
    ``settings.DATABASES``.

    NOTE(review): every unknown alias is registered this way, including
    mistyped ones; the error then only shows up when the connection is used.

    :param alias: database alias (for tenants, the lower-cased tenant id)
    :return: the database connection wrapper for ``alias``
    :raises TypeError: if ``alias`` is not a string
    """
    if not isinstance(alias, str):
        logger.error(f'The connection alias [{alias}] must be string')
        raise TypeError(f'The connection alias [{alias}] must be string')

    try:
        return getattr(self._connections, alias)
    except AttributeError:
        # No connection stored for this alias yet; create one below.
        pass

    if alias not in self.settings:
        organize_db = get_default_db_config(alias)
        if not organize_db:
            logger.error(f"The connection '{alias}' doesn't exist.")
            raise self.exception_class(f"The connection '{alias}' doesn't exist.")
        # Register the tenant database that has no configuration of its own.
        self.settings[alias] = organize_db

    conn = self.create_connection(alias)
    setattr(self._connections, alias, conn)
    return conn


ConnectionHandler.__getitem__ = __connection_handler__getitem__
- contenttype.py 重构contenttype
from django.apps import apps as global_apps
from django.contrib.contenttypes import management
from django.contrib.contenttypes.management import get_contenttypes_and_models
from django.db import DEFAULT_DB_ALIAS
from organize import get_common_apps
def create_contenttypes(
    app_config,
    verbosity=2,
    interactive=True,
    using=DEFAULT_DB_ALIAS,
    apps=global_apps,
    **kwargs
):
    """
    Create content types for models in the given app.

    Apps whose name is listed by ``get_common_apps()`` are skipped.
    """
    if not app_config.models_module:
        return

    label = app_config.label
    try:
        app_config = apps.get_app_config(label)
        ContentType = apps.get_model('contenttypes', 'ContentType')
    except LookupError:
        return

    if app_config.name in get_common_apps():
        return None

    existing_types, models_by_name = get_contenttypes_and_models(app_config, using, ContentType)
    if not models_by_name:
        return

    # Build a content type for every model that does not have one yet.
    missing = []
    for model_name in models_by_name:
        if model_name not in existing_types:
            missing.append(ContentType(app_label=label, model=model_name))
    ContentType.objects.using(using).bulk_create(missing)

    if verbosity < 2:
        return
    for ct in missing:
        print("Adding content type '%s | %s'" % (ct.app_label, ct.model))


management.create_contenttypes = create_contenttypes
- permission.py 重构权限
"""
Creates permissions for all installed apps that need permissions.
"""
from django.apps import apps as global_apps
from django.contrib.contenttypes.management import create_contenttypes
from django.db import DEFAULT_DB_ALIAS, router
from django.contrib.auth.models import Permission
from django.contrib.auth.management import _get_all_permissions
from django.contrib.auth import management
from django.contrib.auth import backends
from organize import get_common_apps
def create_permissions(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, apps=global_apps, **kwargs):
    """Create the missing permissions of ``app_config``'s models in database ``using``.

    Content types are created first. Apps whose name is listed by
    ``get_common_apps()`` are skipped afterwards, as are databases into which
    ``Permission`` may not be migrated.
    """
    if not app_config.models_module:
        return

    create_contenttypes(app_config, verbosity=verbosity, interactive=interactive, using=using, apps=apps, **kwargs)

    label = app_config.label
    try:
        app_config = apps.get_app_config(label)
        ContentType = apps.get_model('contenttypes', 'ContentType')
        Permission = apps.get_model('auth', 'Permission')
    except LookupError:
        return

    if app_config.name in get_common_apps():
        return

    if not router.allow_migrate_model(using, Permission):
        return

    # Permissions that should exist, as (content_type, (codename, name)) pairs,
    # together with the content types they belong to.
    wanted = []
    wanted_ctypes = set()
    content_types = ContentType.objects.db_manager(using)
    for model_cls in app_config.get_models():
        # Look the content type up in the current database before pointing
        # foreign keys at it.
        ctype = content_types.get_for_model(model_cls, for_concrete_model=False)
        wanted_ctypes.add(ctype)
        wanted.extend((ctype, perm) for perm in _get_all_permissions(model_cls._meta))

    # Permissions that already exist for those content types. Filtering by
    # codename is unnecessary: the wanted list already names them all.
    existing = set(
        Permission.objects.using(using)
        .filter(content_type__in=wanted_ctypes)
        .values_list("content_type", "codename")
    )

    new_perms = []
    for ct, (codename, name) in wanted:
        if (ct.pk, codename) in existing:
            continue
        new_perms.append(Permission(codename=codename, name=name, content_type_id=ct.id))
    Permission.objects.using(using).bulk_create(new_perms)

    if verbosity < 2:
        return
    for perm in new_perms:
        print("Adding permission '%s'" % perm)
def _get_group_permissions(self, user_obj):
    """Return the permissions granted to ``user_obj`` through its groups.

    The ``groups`` field is read from ``user_obj``'s own model.
    """
    groups_field = user_obj._meta.get_field('groups')
    lookup = f'group__{groups_field.related_query_name()}'
    return Permission.objects.filter(**{lookup: user_obj})


# Install both replacements on Django's auth modules.
backends.ModelBackend._get_group_permissions = _get_group_permissions
management.create_permissions = create_permissions
- request.py 请求接口根据用户企业id设置数据库连接
from rest_framework.request import Request
from rest_framework import exceptions
from utils.api_db import set_current_db
def __request_authenticate(self):
    """
    Attempt to authenticate the request using each authentication instance
    in turn, and route the current thread to the user's tenant database.

    The database selection lives in thread-local storage. It is reset to
    ``default`` first so a request never inherits the tenant database chosen
    by an earlier request handled on the same thread.
    """
    set_current_db('default')

    for authenticator in self.authenticators:
        try:
            user_auth_tuple = authenticator.authenticate(self)
        except exceptions.APIException:
            self._not_authenticated()
            raise

        if user_auth_tuple is not None:
            self._authenticator = authenticator
            self.user, self.auth = user_auth_tuple
            # Users without an ``organize`` attribute or value stay on ``default``.
            organize = getattr(self.user, 'organize', None) if self.user else None
            if organize:
                # Select the database of the user's organization.
                set_current_db(organize.id.lower())
            return

    self._not_authenticated()


Request._authenticate = __request_authenticate
- app.py 应用初始化
from django.apps import AppConfig
from django.db import connections
from organize import get_all_organize_db
class OrganizeConfig(AppConfig):
    """App config of ``organize``: installs the patches and opens the tenant connections."""

    default_auto_field = 'django.db.models.BigAutoField'
    name = 'organize'

    def ready(self) -> None:
        # These modules are imported for their side effects only.
        from . import signal  # receiver that creates the database of a new organization
        from .patch import connection  # replaces ConnectionHandler.__getitem__
        from .patch import request  # selects the database per request
        from .patch import contenttype  # replaces create_contenttypes
        from .patch import permission  # replaces create_permissions

        # Look every tenant alias up once so that its connection is created.
        for alias in list(get_all_organize_db().keys()):
            connections[alias]
        return super().ready()
- signal.py 信号,创建企业的时候触发创建数据库
from django.dispatch import Signal # 自定义信号
from django.dispatch import receiver
from utils.api_db import create_database
from utils.logger import logger
# Custom signal sent when an organization (tenant) is created.
create_organize = Signal()


@receiver(create_organize)
def create_data_handler(sender, org_id, org_name, **kwargs):
    """Create the database of the new organization; failures are logged, not raised."""
    try:
        create_database(org_id, org_name)
        logger.info(f'create database : [{org_id}] successfuly for {org_name}')
    except Exception as exc:
        logger.error(exc)
- 其他
# Send the signal that triggers creation of the organization's database.
create_organize.send(sender=Organize, org_id=result['uid'].lower(), org_name=result['value'])
# api_view.get_queryset
class APIView(GenericViewSet):
    """Base view set that selects the database before building the queryset."""

    def get_queryset(self):
        """Select the database for this request, then return the parent's queryset.

        * Platform side (current database is ``default``) with an ``organize``
          value in the query string or body: switch to that tenant's database.
        * Tenant side, when ``basename`` is listed by ``get_common_apps()``:
          switch to ``default`` to read platform data.
        * Otherwise the current database is left untouched.
        """
        db = get_current_db()
        organize = self.request.query_params.get('organize', '') or self.request.data.get('organize', '')
        if db == 'default' and organize:
            # Platform reading the data of a single tenant. A missing parameter
            # arrives as '' and must not be used as a database alias.
            set_current_db(organize.lower())
        elif db != 'default' and self.basename in get_common_apps():
            # Tenant reading platform data.
            set_current_db('default')
        return super().get_queryset()
class AllDataModelMixin(object):
    """Lets the platform fetch the data of all tenants."""

    def all_data(self, request):
        """Return one paginated page of the rows gathered by ``get_model_all_data``.

        The optional ``organize`` query parameter is passed through as the tenant id.
        """
        organize_id = request.query_params.get('organize', '')
        rows = get_model_all_data(view=self, organize=organize_id)
        # Paginate the in-memory list.
        paginator = self.paginator
        current_page = paginator.paginate_queryset(rows, request)
        return paginator.get_paginated_response(current_page)