11#!/usr/bin/env python
22# -*- coding: utf-8 -*-
33
4+ from typing import TypeVar , List , Dict
45from app .core .config import settings
5- from sqlalchemy import create_engine
6+ from sqlalchemy import create_engine , text
67from sqlalchemy .orm import sessionmaker
78from app .models .base import Model
89from app .models import system
910from pathlib import Path
1011import orjson
1112
1213
14+ ModelType = TypeVar ("ModelType" , bound = Model )
15+
16+
1317class InitializeData :
1418 """
1519 初始化数据
1620 """
1721
1822 SCRIPT_DIR : Path = Path .joinpath (settings .BASE_DIR , 'scripts' , 'initialize' )
1923
20- def __init__ (self ):
24+ def __init__ (self ) -> None :
2125 self .engine = create_engine (self .__get_db_url (), echo = True , future = True )
2226 self .DBSession = sessionmaker (bind = self .engine )
23-
24- def __get_db_url (self ):
27+ self .prepare_init_models = [
28+ system .DeptModel ,
29+ system .UserModel ,
30+ system .MenuModel ,
31+ system .PositionModel ,
32+ system .RoleModel ,
33+ system .OperationLogModel ,
34+ system .RoleDeptsModel ,
35+ system .RoleMenusModel ,
36+ system .UserPositionsModel ,
37+ system .UserRolesModel
38+ ]
39+
40+ def __get_db_url (self ) -> str :
2541 scheme = settings .SQL_DB_URL .scheme .split ('+' )[0 ]
2642 new_db_url = settings .SQL_DB_URL .unicode_string ().replace (settings .SQL_DB_URL .scheme , scheme )
2743 return new_db_url
2844
29- def __init_model (self ):
45+ def __init_model (self ) -> None :
3046 print ("开始初始化数据库..." )
3147 Model .metadata .create_all (
3248 self .engine ,
33- tables = [
34- system .DeptModel .__table__ ,
35- system .UserModel .__table__ ,
36- system .MenuModel .__table__ ,
37- system .PositionModel .__table__ ,
38- system .RoleModel .__table__ ,
39- system .OperationLogModel .__table__ ,
40- system .RoleDeptsModel .__table__ ,
41- system .RoleMenusModel .__table__ ,
42- system .UserPositionsModel .__table__ ,
43- system .UserRolesModel .__table__
44- ]
49+ tables = [modal .__table__ for modal in self .prepare_init_models ]
4550 )
4651 print ("数据库初始化完成!" )
4752
48- def __init_data (self ):
53+ def __init_data (self ) -> None :
4954 print ("开始初始化数据..." )
50- self .__init_dept ()
51- self .__init_user ()
52- self .__init_menu ()
53- self .__init_position ()
54- self .__init_role ()
55- self .__init_role_depts ()
56- self .__init_role_menus ()
57- self .__init_user_positions ()
58- self .__init_user_roles ()
55+
56+ for model in self .prepare_init_models :
57+ max_rows = self .__generate_data (model )
58+ self .__update_sequence (model , max_rows )
59+
5960 print ("数据初始化完成!" )
6061
61- def __generate_data (self , table_name : str , model : Model ) :
62+ def __generate_data (self , model : ModelType ) -> int :
6263 session = self .DBSession ()
6364
65+ table_name = model .__tablename__
66+
6467 data = self .__get_data (table_name )
6568 objs = [model (** item ) for item in data ]
6669 session .add_all (objs )
@@ -69,40 +72,41 @@ def __generate_data(self, table_name: str, model: Model):
6972 session .close ()
7073 print (f"{ table_name } 表数据已生成!" )
7174
72- def __get_data (self , filename : str ):
73- json_path = Path .joinpath (self .SCRIPT_DIR , 'data' , f'{ filename } .json' )
74- with open (json_path , 'r' , encoding = 'utf-8' ) as f :
75- data = orjson .loads (f .read ())
76- return data
75+ return len (objs )
7776
78- def __init_dept (self ):
79- self .__generate_data ("system_dept" , system .DeptModel )
77+ def __get_data (self , filename : str ) -> List [Dict ]:
78+ try :
79+ json_path = Path .joinpath (self .SCRIPT_DIR , 'data' , f'{ filename } .json' )
80+ with open (json_path , 'r' , encoding = 'utf-8' ) as f :
81+ data = orjson .loads (f .read ())
82+ return data
8083
81- def __init_menu ( self ) :
82- self . __generate_data ( "system_menu" , system . MenuModel )
84+ except FileNotFoundError :
85+ return []
8386
84- def __init_position (self ) :
85- self . __generate_data ( "system_position" , system . PositionModel )
87+ def __update_sequence (self , model : ModelType , max_rows : int ) -> None :
88+ table_name = model . __tablename__
8689
87- def __init_role (self ):
88- self .__generate_data ("system_role" , system .RoleModel )
90+ # 检查模型中是否有自增字段
91+ sequence_name = None
92+ for col in model .__table__ .columns :
93+ if col .autoincrement is True :
94+ sequence_name = f"{ table_name } _{ col .name } _seq"
95+ break
8996
90- def __init_role_depts (self ):
91- self .__generate_data ("system_role_depts" , system .RoleDeptsModel )
92-
93- def __init_role_menus (self ):
94- self .__generate_data ("system_role_menus" , system .RoleMenusModel )
95-
96- def __init_user (self ):
97- self .__generate_data ("system_user" , system .UserModel )
97+ if not sequence_name :
98+ print (f"{ table_name } 表无需设置自增序列值" )
99+ return
98100
99- def __init_user_positions (self ):
100- self .__generate_data ("system_user_positions" , system .UserPositionsModel )
101+ session = self .DBSession ()
101102
102- def __init_user_roles (self ):
103- self .__generate_data ("system_user_roles" , system .UserRolesModel )
103+ # 更新序列最大值
104+ new_value = max_rows + 1
105+ session .execute (text (f"ALTER SEQUENCE { sequence_name } RESTART WITH { new_value } " ))
106+ session .commit ()
107+ print (f"{ table_name } 表的自增序列值已更新!" )
104108
105- def run (self ):
109+ def run (self ) -> None :
106110 """
107111 执行初始化
108112 """
0 commit comments