observability.svc.plus/bin/psql.py
2026-02-01 20:53:55 +08:00

222 lines
6.8 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*- #
# License : Apache-2.0 @ https://pigsty.io/docs/about/license/
# Copyright : 2018-2026 Ruohang Feng / Vonng (rh@vonng.com)
__author__ = 'Vonng (rh@vonng.com)'
# A Wrapper for psycopg2
import psycopg2
import psycopg2.extras
from psycopg2.extras import Json
DEFAULT_BUF_SIZE = 2000
# PSQL is a convenient wrapper for psycopg2
# Usage:
# from psql import *
# pg = PSQL('service=meta')
# c.list()
# c.execute('SELECT 1')
class PSQL(object):
def __init__(self, url=''):
self.url = url
self.conn = psycopg2.connect(url)
def reconnect(self):
self.conn = psycopg2.connect(self.url)
def execute(self, sql, data=None):
with self.conn.cursor() as cursor:
try:
cursor.execute(sql, vars=data)
self.conn.commit()
return cursor.rowcount
except Exception:
self.conn.rollback()
raise
def execute_many(self, sql, data):
with self.conn.cursor() as cursor:
try:
cursor.executemany(sql, vars_list=data)
self.conn.commit()
return cursor.rowcount
except:
self.conn.rollback()
raise
def execute_mono(self, sql, data_seq, skip_error=True):
'''
Execute sql for each record one by one.
'''
total_num = 0
affect_num = 0
with self.conn.cursor() as cursor:
for record in data_seq:
total_num += 1
try:
cursor.execute(sql, record)
self.conn.commit()
affect_num += 1
except:
self.conn.rollback()
if not skip_error:
raise
# rollback and skip
return total_num, affect_num
def iexecute(self, sql, idata=None, buffer=True, buf_sz=DEFAULT_BUF_SIZE, skip_error=False):
if not buffer:
return self.execute_mono(sql, idata, skip_error)
buf = []
total_num = 0
affect_num = 0
for record in idata:
buf.append(record)
total_num += 1
if total_num % buf_sz == 0:
try:
affect_num += self.execute_many(sql, buf)
except:
if not skip_error: raise
finally:
buf = []
else:
# commit remain data
if len(buf) > 0:
try:
affect_num += self.execute_many(sql, buf)
except:
if not skip_error: raise
return total_num, affect_num
def fetch(self, sql, data=None):
with self.conn.cursor() as cursor:
cursor.execute(sql, vars=data)
return cursor.fetchall()
def fetch_one(self, sql, data=None):
with self.conn.cursor() as cursor:
cursor.execute(sql)
return cursor.fetchone()
def fetch_scale(self, sql, data=None):
with self.conn.cursor() as cursor:
cursor.execute(sql, data)
res = cursor.fetchone()
return res[0] if res and len(res) > 0 else None
def fetch_column(self, sql, data=None):
with self.conn.cursor() as cursor:
cursor.execute(sql)
return [item[0] for item in cursor.fetchall()]
def ifetch(self, sql, name=None, buf_size=DEFAULT_BUF_SIZE):
'''
Suit for large bulk selection. Use fetchmany & ServerSide cursor
'''
# if name is None:
# name = "ifetch_%s" % (int(time.time()))
with self.conn.cursor(name=name) as cursor:
cursor.arraysize = buf_size
cursor.execute(sql)
while True:
buf = cursor.fetchmany()
n_records = len(buf)
if n_records == 0:
return
for record in buf:
yield record
def call(self, func, args):
with self.conn.cursor() as cursor:
cursor.callproc(func, args)
return cursor.fetchall()
def count(self, table, condition=None):
if condition and condition != '':
sql = "SELECT count(*) FROM {0} WHERE {1};".format(table, condition)
else:
sql = "SELECT count(*) FROM {0};".format(table)
return self.fetch_scale(sql)
def glimpse(self, table, limit=None):
'''
Inspect a table by sampling some data
'''
if limit and int(limit) > 0:
sql = "SELECT * FROM {0} LIMIT {1};".format(table, limit)
else:
sql = "SELECT * FROM {0};".format(table)
return self.fetch(sql)
def desc(self, table, schema='public'):
sql = "SELECT ordinal_position,column_name,data_type FROM information_schema.columns " \
"WHERE table_schema= '{0}' and table_name = '{1}' ORDER BY ordinal_position;".format(
schema, table
)
return self.fetch(sql)
def drop(self, table):
'''
Drop a specfic table by name
'''
sql = "DROP TABLE IF EXISTS {0};".format(table)
self.execute(sql)
def truncate(self, table):
'''
Delete a specific table by table name
'''
sql = "DELETE FROM {0};".format(table)
self.execute(sql)
def rename(self, table, new_name):
'''
Delete a specific table by table name
'''
sql = "ALTER TABLE {0} RENAME TO {1};".format(table, new_name)
self.execute(sql)
def list(self, schema='public'):
sql = "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname = '{0}'".format(schema)
return self.fetch_column(sql)
def insert(self, table, columns, value):
'''
Insert single record
'''
col_spec = ','.join(columns)
val_spec = ','.join(['%s'] * len(columns))
sql = self.insert_sql(table, columns)
return self.execute(sql, value)
def insert_json(self, table, json_data):
'''
Insert single json record. Only for convenient use.
'''
columns = []
values = []
for column, value in json_data.iteritems():
columns.append(column)
values.append(value)
sql = self.insert_sql(table, columns)
return self.execute(sql, values)
@staticmethod
def insert_sql(table, columns, do_nothing=False):
col_spec = ','.join(columns)
val_spec = ','.join(['%s'] * len(columns))
sql = 'INSERT INTO {0} ({1}) VALUES ({2})'.format(table, col_spec, val_spec)
if do_nothing:
sql = 'INSERT INTO {0} ({1}) VALUES ({2}) ON CONFLICT DO NOTHING;'.format(table, col_spec, val_spec)
return sql