282 lines
12 KiB
Python
282 lines
12 KiB
Python
|
||
import asyncio, logging, aiomysql
|
||
|
||
# 创建基本日志函数
|
||
def log(sql, args=()):
|
||
logging.info('SQL: %s' % sql)
|
||
|
||
# 创建连接池函数
|
||
async def create_pool(loop, **kw):
|
||
logging.info('create database connection pool...')
|
||
# 声明 __pool 为全局变量
|
||
global __pool
|
||
# 使用这些基本参数来创建连接池
|
||
# await 和 async 是联动的(异步IO)
|
||
# 连接池是一种标准技术,用于在内存中维护长时间运行的连接,以便有效地重用,
|
||
# 并未应用程序可能同时使用的连接总数提供管理。特别是对于服务器端Web应用程序,
|
||
# 连接池是内存中维护活动数据库连接池的标准方法,这些活动数据库连接在请求之间重复使用。
|
||
# 使用这些基本参数来创建连接池
|
||
# await 和 async 是联动的(异步IO)
|
||
__pool = await aiomysql.create_pool(
|
||
host=kw.get('host', 'localhost'),
|
||
port=kw.get('port', 3306),
|
||
user=kw['user'],
|
||
password=kw['password'],
|
||
db=kw['db'],
|
||
charset=kw.get('charset', 'utf8'),
|
||
autocommit=kw.get('autocommit', True),
|
||
maxsize=kw.get('maxsize', 10),
|
||
minsize=kw.get('minsize', 1),
|
||
loop=loop
|
||
)
|
||
|
||
async def select(sql, args, size=None):
|
||
log(sql, args)
|
||
global __pool
|
||
|
||
|
||
# 防止多个程序同时执行,达到异步效果
|
||
with (await __pool) as conn:
|
||
|
||
# 'aiomysql.DictCursor'要求返回字典格式
|
||
cur = await conn.cursor(aiomysql.DictCursor)
|
||
# cursor 游标实例可以调用 execute 来执行一条单独的 SQL 语句
|
||
await cur.execute(sql.replace('?', '%s'), args or())
|
||
# size 为空时为 False,上面定义了初始值为 None ,具体得看传入的参数有没有定义 size
|
||
if size:
|
||
# fetchmany 可以获取行数为 size 的多行查询结果集,返回一个列表
|
||
rs = await cur.fetchmany(size)
|
||
else:
|
||
# fetchall 可以获取一个查询结果的所有(剩余)行,返回一个列表
|
||
rs = await cur.fetchall()
|
||
# close() ,立即关闭 cursor ,从这一时刻起该 cursor 将不再可用
|
||
await cur.close()
|
||
# 日志:提示返回了多少行
|
||
logging.info('rows returned: %s' % len(rs))
|
||
# select 函数给我们从 SQL 返回了一个列表
|
||
return rs
|
||
|
||
# execute :执行
|
||
async def execute(sql, args):
|
||
log(sql)
|
||
global __pool
|
||
with (await __pool) as conn:
|
||
try:
|
||
cur = await conn.cursor()
|
||
await cur.execute(sql.replace('?', '%s'),args)
|
||
# rowcount 获取行数,应该表示的是该函数影响的行数
|
||
affected = cur.rowcount
|
||
await cur.close()
|
||
except BaseException as _:
|
||
# except BaseException as e:
|
||
# 将错误抛出,BaseEXception 是异常不用管
|
||
raise
|
||
# 返回行数
|
||
return affected
|
||
|
||
|
||
|
||
def create_args_string(num):
|
||
L = []
|
||
for _ in range(num):
|
||
L.append('?')
|
||
return ', '.join(L)
|
||
|
||
# Model 是一个基类,所以先定义 ModelMetaclass ,再在定义 Model 时使用 metaclass 参数
|
||
class ModelMetaclass(type):
|
||
# __new__()方法接收到的参数依次是:
|
||
# cls:当前准备创建的类的对象 class
|
||
# name:类的名字 str
|
||
# bases:类继承的父类集合 Tuple
|
||
# attrs:类的方法集合
|
||
def __new__(cls, name, bases, attrs):
|
||
# 排除 Model 类本身,返回它自己
|
||
if name=='Model':
|
||
return type.__new__(cls, name, bases, attrs)
|
||
# 获取 table 名称
|
||
tableName = attrs.get('__table__', None) or name
|
||
# 日志:找到名为 name 的 model
|
||
logging.info('found model: %s (table: %s)' % (name, tableName))
|
||
# 获取 所有的 Field 和主键名
|
||
mappings = dict()
|
||
fields = []
|
||
primaryKey = None
|
||
# attrs.items 取决于 __new__ 传入的 attrs 参数
|
||
for k,v in attrs.items():
|
||
# isinstance 函数:如果 v 和 Field 类型相同则返回 True ,不相同则 False
|
||
if isinstance(v, Field):
|
||
logging.info(' found mapping: %s ==> %s' % (k,v))
|
||
mappings[k] = v
|
||
# 这里的 v.primary_key 我理解为 :只要 primary_key 为 True 则这个 field 为主键
|
||
if v.primary_key:
|
||
# 找到主键,如果主键 primaryKey 有值时,返回一个错误
|
||
if primaryKey:
|
||
raise RuntimeError('Duplicate primary key for field: %s' % k)
|
||
# 然后直接给主键赋值
|
||
primaryKey = k
|
||
else:
|
||
# 没找到主键的话,直接在 fields 里加上 k
|
||
fields.append(k)
|
||
if not primaryKey:
|
||
# 如果主键为 None 就报错
|
||
raise RuntimeError('Primary key not found.')
|
||
for k in mappings.keys():
|
||
# pop :如果 key 存在于字典中则将其移除并返回其值,否则返回 default
|
||
attrs.pop(k)
|
||
|
||
escaped_fields = list(map(lambda f: '`%s`' % f, fields))
|
||
attrs['__mappings__'] = mappings # 保存属性和列的映射关系
|
||
attrs['__table__'] = tableName # table 名
|
||
attrs['__primary_key__'] = primaryKey # 主键属性名
|
||
attrs['__fields__'] = fields # 除主键外的属性名
|
||
# 构造默认的 SELECT, INSERT, UPDAT E和 DELETE 语句
|
||
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
|
||
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
|
||
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
|
||
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
|
||
return type.__new__(cls, name, bases, attrs)
|
||
|
||
# metaclass 参数提示 Model 要通过上面的 __new__ 来创建
|
||
class Model(dict, metaclass=ModelMetaclass):
|
||
def __init__(self, **kw):
|
||
# super 用来引用父类? 引用了 ModelMetaclass ? super 文档:
|
||
super(Model, self).__init__(**kw)
|
||
# 返回参数为 key 的自身属性, 如果出错则报具体错误
|
||
def __getattr__(self, key):
|
||
try:
|
||
return self[key]
|
||
except KeyError:
|
||
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
|
||
# 设置自身属性
|
||
def __setattr__(self, key, value):
|
||
self[key] = value
|
||
# 通过属性返回想要的值
|
||
def getValue(self, key):
|
||
return getattr(self, key, None)
|
||
#
|
||
def getValueOrDefault(self, key):
|
||
value = getattr(self, key, None)
|
||
if value is None:
|
||
# 如果 value 为 None,定位某个键; value 不为 None 就直接返回
|
||
field = self.__mappings__[key]
|
||
if field.default is not None:
|
||
# 如果 field.default 不是 None : 就把它赋值给 value
|
||
value = field.default() if callable(field.default) else field.default
|
||
logging.debug('using default value for %s: %s' % (key,str(value)))
|
||
setattr(self, key, value)
|
||
return value
|
||
|
||
# *** 往 Model 类添加 class 方法,就可以让所有子类调用 class 方法
|
||
@classmethod
|
||
async def findAll(cls, where=None, args=None, **kw):
|
||
## find objects by where clause
|
||
sql = [cls.__select__]
|
||
# where 默认值为 None
|
||
# 如果 where 有值就在 sql 加上字符串 'where' 和 变量 where
|
||
if where:
|
||
sql.append('where')
|
||
sql.append(where)
|
||
if args is None:
|
||
# args 默认值为 None
|
||
# 如果 findAll 函数未传入有效的 where 参数,则将 '[]' 传入 args
|
||
args = []
|
||
|
||
orderBy = kw.get('orderBy', None)
|
||
if orderBy:
|
||
# get 可以返回 orderBy 的值,如果失败就返回 None ,这样失败也不会出错
|
||
# oederBy 有值时给 sql 加上它,为空值时什么也不干
|
||
sql.append('order by')
|
||
sql.append(orderBy)
|
||
# 开头和上面 orderBy 类似
|
||
limit = kw.get('limit', None)
|
||
if limit is not None:
|
||
sql.append('limit')
|
||
if isinstance(limit, int):
|
||
# 如果 limit 为整数
|
||
sql.append('?')
|
||
args.append(limit)
|
||
elif isinstance(limit, tuple) and len(limit) == 2:
|
||
# 如果 limit 是元组且里面只有两个元素
|
||
sql.append('?, ?')
|
||
# extend 把 limit 加到末尾
|
||
args.extend(limit)
|
||
else:
|
||
raise ValueError('Invalid limit value: %s' % str(limit))
|
||
rs = await select(' '.join(sql), args)
|
||
# 返回选择的列表里的所有值 ,完成 findAll 函数
|
||
return [cls(**r) for r in rs]
|
||
|
||
@classmethod
|
||
async def findNumber(cls, selectField, where=None, args=None):
|
||
## find number by select and where
|
||
#找到选中的数及其位置
|
||
sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
|
||
if where:
|
||
sql.append('where')
|
||
sql.append(where)
|
||
rs = await select(' '.join(sql), args, 1)
|
||
if len(rs) == 0:
|
||
# 如果 rs 内无元素,返回 None ;有元素就返回某个数
|
||
return None
|
||
return rs[0]['_num_']
|
||
|
||
@classmethod
|
||
async def find(cls, pk):
|
||
## find object by primary key
|
||
# 通过主键找对象
|
||
rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
|
||
if len(rs) == 0:
|
||
return None
|
||
return cls(**rs[0])
|
||
|
||
# *** 往 Model 类添加实例方法,就可以让所有子类调用实例方法
|
||
async def save(self):
|
||
args = list(map(self.getValueOrDefault, self.__fields__))
|
||
args.append(self.getValueOrDefault(self.__primary_key__))
|
||
rows = await execute(self.__insert__, args)
|
||
if rows != 1:
|
||
logging.warning('failed to insert record: affected rows: %s' % rows)
|
||
|
||
async def update(self):
|
||
args = list(map(self.getValue, self.__fields__))
|
||
args.append(self.getValue(self.__primary_key__))
|
||
rows = await execute(self.__update__, args)
|
||
if rows != 1:
|
||
logging.warning('failed to update by primary key: affected rows: %s' % rows)
|
||
|
||
async def remove(self):
|
||
args = [self.getValue(self.__primary_key__)]
|
||
rows = await execute(self.__delete__, args)
|
||
if rows != 1:
|
||
logging.warning('failed to remove by primary key: affected rows: %s' % rows)
|
||
|
||
# 定义 Field
|
||
class Field(object):
|
||
def __init__(self, name, column_type, primary_key, default):
|
||
self.name = name
|
||
self.column_type = column_type
|
||
self.primary_key = primary_key
|
||
self.default = default
|
||
def __str__(self):
|
||
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
|
||
# 定义 Field 子类及其子类的默认值
|
||
class StringField(Field):
|
||
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
|
||
super().__init__(name, ddl, primary_key, default)
|
||
|
||
class BooleanField(Field):
|
||
def __init__(self, name=None, default=False):
|
||
super().__init__(name, 'boolean', False, default)
|
||
|
||
class IntegerField(Field):
|
||
def __init__(self, name=None, primary_key=False, default=0):
|
||
super().__init__(name, 'bigint', primary_key, default)
|
||
|
||
class FloatField(Field):
|
||
def __init__(self, name=None, primary_key=False, default=0):
|
||
super().__init__(name, 'real', primary_key,default)
|
||
|
||
class TextField(Field):
|
||
def __init__(self, name=None, default=None):
|
||
super().__init__(name, 'text', False, default)
|