0%

优雅的Python库-records[数据库]

简介

  • records 是由 Ken Reitz 大神(requests、pipenv、tablib、maya、requests-html 的作者)开发的,一个符合人类习惯、简单而强大的原生 SQL 执行库
  • 执行 SQL,并以字典列表的形式获取结果

对比

psycopg2

1
2
3
4
5
6
7
8
9
10
11
import psycopg2
import psycopg2.extras

# Open the connection and ask for a cursor whose rows behave like dicts.
conn = psycopg2.connect(
    database="testdb",
    user="postgres",
    password="pass123",
    host="127.0.0.1",
    port=5432,
)
cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)


def fetch_all_as_dict(cursor):
    """Return every remaining row of *cursor* as a plain ``dict``."""
    return [dict(row) for row in cursor]


cur.execute("SELECT * FROM test WHERE id=%s;", (1,))
# cur.fetchone()
print(fetch_all_as_dict(cur))
cur.close()
conn.close()

peewee

1
2
3
4
5
6
7
from peewee import *

db = SqliteDatabase('people.db')
# The SQL is handed to the SQLite driver as-is, and that driver expects "?"
# placeholders rather than "%s".
cursor = db.execute_sql('select count(*) from test where id=?', params=(1,))


def fetch_all_as_dict(cursor):
    """Return every remaining row of *cursor* as a dict keyed by column name."""
    # The first item of each DB-API ``description`` entry is the column name.
    col_names = [column[0] for column in cursor.description]
    return [dict(zip(col_names, row)) for row in cursor.fetchall()]


print(fetch_all_as_dict(cursor))
cursor.close()

records

1
2
3
4
import records
# Connect, run a parameterised query and print the rows as a list of dicts.
db = records.Database("mysql://username:password@hostname:port/databasename")
rows = db.query(
    "SELECT name,age FROM people WHERE name=:name",
    name="zhangsan",
)
print(rows.all(as_dict=True))

使用

初始化

1
2
3
4
import records

# Initialise the database connection; the URL can also be read from the DATABASE_URL environment variable
db = records.Database("mysql://username:password@hostname:port/databasename")

查询

  • 查询
    • 默认RecordCollection 对象
    • 通过 as_dict=True 转化为字典
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
rows = db.query("SELECT name,age FROM people WHERE name=:name",name="zhangsan")

# The result can be iterated row by row
for row in rows:
    print(row.name, row.get("age"))

# Record 对象
>>> rows[0]
<Record {"name": "zhangsan", "age": 18}>
>>> rows[0].as_dict()
{'name': 'zhangsan', 'age': 18}
>>> rows[0].name
'zhangsan'
>>> rows[0].get('age')
18

# Fetch every row
rows.all()

# Fetch every row as a list of dicts
rows.as_dict()

# Fetch the single expected row (raises if the result has more than one row)
rows.one(as_dict=False)

# Fetch the first row as a dict
rows.first(as_dict=True)

# Run the SQL stored in a file
rows = db.query_file("a.sql")

操作

  • 操作
    • :variable 安全传参
    • bulk_query 批量操作
1
2
3
4
5
6
7
# Update: the dict is unpacked into the :name / :age bound parameters
user = {"name": "zhang", "age": 18}
db.query("UPDATE people SET age=:age WHERE name=:name", **user)

# Bulk insert: one parameter dict per row
users = [
    {"name": "zhang", "age": 18},
    {"name": "zhao", "age": 18},
]
db.bulk_query("INSERT INTO people(name,age) VALUES(:name, :age)", users)
  • 事务
1
2
3
4
with db.transaction() as tr:
    # Transfer 30 from zhangsan to lisi; both updates belong to one transaction
    tr.query('UPDATE people set money=money-30 where name="zhangsan"')
    tr.query('UPDATE people set money=money+30 where name="lisi" ')

记录导出

1
2
3
4
5
6
7
8
9

# JSON
rows = db.query('SELECT * FROM people;')
json_rows = rows.export('json')
print(json_rows)

# Binary formats such as xlsx have to be written straight to a file
with open('users.xlsx', 'wb') as f:
    f.write(rows.export('xlsx'))

其他

  • 提供cli:records <query> [<format>] [<params>...] [--url=<url>]
1
2
# Export the result of a SQL query to out.csv
records 'sql查询语句' csv --url='数据库url' > out.csv

代码

records-class

Record 单条记录

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class Record(object):
    """A single row of a query result.

    Values can be read by position (``r[0]``), by column name (``r['name']``),
    as an attribute (``r.name``) or through :meth:`get`.
    """

    # Restrict instance attributes to reduce per-instance memory use.
    __slots__ = ('_keys', '_values')

    def __init__(self, keys, values):
        self._keys = keys
        self._values = values
        # Every column name must be paired with exactly one value.
        assert len(self._keys) == len(self._values)

    def keys(self):
        """Return the column names, hiding the ``_keys`` attribute."""
        return self._keys

    def values(self):
        """Return the column values, hiding the ``_values`` attribute."""
        return self._values

    def __getitem__(self, key):
        """Return a value by integer position or by column name.

        Raises:
            KeyError: if the name is unknown or appears more than once.
        """
        # Positional access: r[1]
        if isinstance(key, int):
            return self.values()[key]
        # Name access: r['key']
        if key in self.keys():
            i = self.keys().index(key)
            # A duplicated column name is ambiguous, so refuse to guess.
            if self.keys().count(key) > 1:
                raise KeyError("Record contains multiple '{}' fields.".format(key))
            return self.values()[i]

        raise KeyError("Record contains no '{}' field.".format(key))

    def __getattr__(self, key):
        """Expose columns as attributes: ``r.key``."""
        try:
            return self[key]
        except KeyError as e:
            # Attribute access must fail with AttributeError, not KeyError.
            raise AttributeError(e)

    def __dir__(self):
        """List the regular attributes together with the column names."""
        standard = dir(super(Record, self))
        # Merge the standard attributes with the dynamically exposed columns.
        return sorted(standard + [str(k) for k in self.keys()])

    def get(self, key, default=None):
        """Return the value for *key*, or *default* if the lookup fails."""
        try:
            return self[key]
        except KeyError:
            return default

    def as_dict(self, ordered=False):
        """Return the record as a ``dict``, or an ``OrderedDict`` if *ordered*."""
        items = zip(self.keys(), self.values())
        return OrderedDict(items) if ordered else dict(items)

    @property
    def dataset(self):
        """A ``tablib.Dataset`` holding this record as its only row."""
        data = tablib.Dataset()
        data.headers = self.keys()
        # Date/time values are converted to ISO-format strings first.
        row = _reduce_datetimes(self.values())
        data.append(row)
        return data

    def export(self, format, **kwargs):
        """Export the record (xlsx, json, csv, ...) through its tablib dataset."""
        return self.dataset.export(format, **kwargs)

RecordCollection 记录集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class RecordCollection(object):
    """A lazily consumed collection of rows that caches what it has read.

    ``rows`` is an iterator; rows are pulled from it on demand and kept in
    ``_all_rows`` so they can be indexed and iterated again.
    """

    def __init__(self, rows):
        self._rows = rows
        self._all_rows = []
        # True until the underlying iterator has been exhausted.
        self.pending = True

    def __iter__(self):
        """Yield cached rows first, then pull the rest from the source."""
        i = 0
        while True:
            # Other methods may already have consumed rows; serve those from
            # the cache before touching the source iterator.
            if i < len(self):
                yield self[i]
            else:
                # A StopIteration must not leak out of a generator, see
                # https://www.python.org/dev/peps/pep-0479/
                try:
                    yield next(self)
                except StopIteration:
                    return
            i += 1

    def __next__(self):
        """Pull one row from the source iterator and cache it."""
        try:
            nextrow = next(self._rows)
            self._all_rows.append(nextrow)
            return nextrow
        except StopIteration:
            self.pending = False
            raise StopIteration('RecordCollection contains no more rows.')

    def __getitem__(self, key):
        """Return one row for an int key, or a new collection for a slice.

        Rows are pulled from the source until the requested range is cached.
        """
        is_int = isinstance(key, int)
        # Convert RecordCollection[1] into slice.
        if is_int:
            key = slice(key, key + 1)

        # Test for an open-ended slice first: comparing an int with None
        # raises TypeError on Python 3.
        while key.stop is None or len(self) < key.stop:
            try:
                next(self)
            except StopIteration:
                break
        rows = self._all_rows[key]
        if is_int:
            return rows[0]
        else:
            return RecordCollection(iter(rows))

    def __len__(self):
        """Return the number of rows cached so far."""
        return len(self._all_rows)

    def all(self, as_dict=False, as_ordereddict=False):
        """Return every row as a list, optionally converted to (ordered) dicts."""
        # By calling list it calls the __iter__ method
        rows = list(self)
        if as_dict:
            return [r.as_dict() for r in rows]
        if as_ordereddict:
            return [r.as_dict(ordered=True) for r in rows]
        return rows

    def as_dict(self, ordered=False):
        """Return every row as a list of dicts (ordered dicts if *ordered*)."""
        return self.all(as_dict=not(ordered), as_ordereddict=ordered)

    def first(self, default=None, as_dict=False, as_ordereddict=False):
        """Return the first row.

        If there is no row, *default* is raised when it is an exception
        instance or class, and returned otherwise.
        """
        try:
            record = self[0]
        except IndexError:
            if isexception(default):
                raise default
            return default

        if as_dict:
            return record.as_dict()
        elif as_ordereddict:
            return record.as_dict(ordered=True)
        else:
            return record

    def one(self, default=None, as_dict=False, as_ordereddict=False):
        """Return the only row; raise ValueError if there is more than one."""
        try:
            # Probing for a second row tells us whether the result is unique.
            self[1]
        except IndexError:
            return self.first(default=default, as_dict=as_dict, as_ordereddict=as_ordereddict)
        else:
            raise ValueError('RecordCollection contained more than one row. '
                             'Expects only one row when using '
                             'RecordCollection.one')


DataBase 数据库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

class Database(object):
    """A database handle wrapping an SQLAlchemy engine and its DB URL."""

    def __init__(self, db_url=None, **kwargs):
        """Create the engine for *db_url*, or for ``$DATABASE_URL`` if omitted.

        Extra keyword arguments are passed on to ``create_engine``.

        Raises:
            ValueError: if no URL is given and the environment has none.
        """
        self.db_url = db_url or os.environ.get('DATABASE_URL')
        if not self.db_url:
            raise ValueError('You must provide a db_url.')
        # Create an engine.
        self._engine = create_engine(self.db_url, **kwargs)
        self.open = True

    def close(self):
        """Dispose of the engine and mark the database as closed."""
        self._engine.dispose()
        self.open = False

    def get_table_names(self, internal=False):
        """Return the table names of the connected database.

        NOTE: ``internal`` is accepted but not used by this implementation.
        """
        return inspect(self._engine).get_table_names()

    def get_connection(self):
        """Return a new ``Connection``; raise if the database was closed."""
        if not self.open:
            raise exc.ResourceClosedError('Database closed.')
        return Connection(self._engine.connect())

    def query(self, query, fetchall=False, **params):
        """Run *query* on a fresh connection and return its result.

        Keyword arguments are bound to the ``:name`` parameters of the query.
        """
        with self.get_connection() as conn:
            return conn.query(query, fetchall, **params)

    def bulk_query(self, query, *multiparams):
        """Run *query* once per parameter set (bulk insert / update)."""
        with self.get_connection() as conn:
            conn.bulk_query(query, *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Run the SQL stored in the file at *path* on a fresh connection."""
        with self.get_connection() as conn:
            return conn.query_file(path, fetchall, **params)

    @contextmanager
    def transaction(self):
        """Context manager yielding a connection inside a transaction.

        The transaction is committed when the block finishes normally. On
        any error it is rolled back and the error is re-raised, so the
        caller can tell that nothing was committed. The connection is
        closed in both cases.
        """
        conn = self.get_connection()
        tx = conn.transaction()
        try:
            yield conn
            tx.commit()
        except BaseException:
            tx.rollback()
            # Swallowing the error would make a failed transaction look
            # like a successful one.
            raise
        finally:
            conn.close()

Connection 连接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

class Connection(object):
    """A database connection."""

    # ... (the constructor, close() and other members are omitted from this
    # excerpt; the methods below rely on ``self._conn`` and ``self.close()``)

    def __enter__(self):
        return self

    def __exit__(self, exc, val, traceback):
        # Always release the connection when the ``with`` block ends.
        self.close()

    def query(self, query, fetchall=False, **params):
        """Execute *query* and return its rows as a ``RecordCollection``.

        Keyword arguments are bound to the query's parameters. Rows are
        produced lazily unless *fetchall* is true, in which case they are all
        read before returning.
        """
        # Execute the given SQL.
        cursor = self._conn.execute(text(query), **params)
        # Row-by-row Record generator.
        row_gen = (Record(cursor.keys(), row) for row in cursor)
        # Wrap the generator so rows are cached as they are consumed.
        results = RecordCollection(row_gen)
        # Honour ``fetchall`` by consuming every row up front.
        if fetchall:
            results.all()
        return results

    def bulk_query(self, query, *multiparams):
        """Execute *query* once per parameter set (bulk insert / update)."""
        self._conn.execute(text(query), *multiparams)

    def query_file(self, path, fetchall=False, **params):
        """Read SQL from the file at *path* and run it through :meth:`query`.

        Raises:
            IOError: if *path* does not exist or is a directory.
        """
        # If path doesn't exists
        if not os.path.exists(path):
            raise IOError("File '{}' not found!".format(path))
        # If it's a directory
        if os.path.isdir(path):
            raise IOError("'{}' is a directory!".format(path))
        # Read the given .sql file into memory.
        with open(path) as f:
            query = f.read()
        # Defer processing to self.query method.
        return self.query(query=query, fetchall=fetchall, **params)

其他

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def isexception(obj):
    """Return True if *obj* is an ``Exception`` instance or subclass."""
    if isinstance(obj, Exception):
        return True
    # issubclass() only accepts classes, so guard it with isclass().
    return isclass(obj) and issubclass(obj, Exception)

def _reduce_datetimes(row):
    """Return *row* as a tuple with date/time values as ISO-format strings.

    Any item that has an ``isoformat`` attribute is replaced by the result of
    calling it; every other item is kept unchanged.
    """
    return tuple(
        item.isoformat() if hasattr(item, 'isoformat') else item
        for item in row
    )

def cli():
    """Command-line entry point (its implementation is omitted from this excerpt)."""
    # ... command-line support

欢迎关注我的其它发布渠道