Source code for bi_etl.database.database_metadata

# -*- coding: utf-8 -*-
"""
Created on Dec 23, 2015

@author: Derek Wood
"""
import logging
import textwrap

import sqlalchemy
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session
from sqlalchemy.sql.schema import DEFAULT_NAMING_CONVENTION

from bi_etl.utility.case_insentive_set import CaseInsentiveSet


class DatabaseMetadata(sqlalchemy.schema.MetaData):
    """
    A light wrapper over :class:`sqlalchemy.schema.MetaData`

    Besides the table metadata kept by the base class, instances also hold:

    * ``bind`` -- the engine used for connections, reflection and raw DBAPI calls.
    * a pool of named connections (see :meth:`connection`).
    * the transaction last begun on each named connection (see :meth:`begin`).
    * a per-schema cache of table names (see :meth:`table_inventory`).
    """

    def __init__(
            self,
            bind=None,
            reflect=False,
            schema=None,
            quote_schema=None,
            naming_convention=DEFAULT_NAMING_CONVENTION,
            info=None,
            database_name=None,
            uses_bytes_length_limits=None,
    ):
        """
        Parameters
        ----------
        bind:
            The engine to use. A ``str`` or :class:`sqlalchemy.engine.url.URL`
            is turned into an engine with :func:`sqlalchemy.create_engine`
            (this is what :meth:`__reduce_ex__` passes back when unpickling).
        reflect:
            Not forwarded to the base class; only saved for :meth:`__reduce_ex__`.
        schema, quote_schema, naming_convention, info:
            Passed through to :class:`sqlalchemy.schema.MetaData`.
        database_name:
            Stored as ``self.database_name``.
        uses_bytes_length_limits:
            Overrides the dialect-based default of :attr:`uses_bytes_length_limits`.
            ``None`` means "use the dialect default".
        """
        super().__init__(
            schema=schema,
            quote_schema=quote_schema,
            naming_convention=naming_convention,
            info=info,
        )
        # __reduce_ex__ pickles bind.url instead of the engine. SQLAlchemy 1.x's
        # MetaData.bind setter converted a URL/string into an engine by itself;
        # SQLAlchemy 2.x has no such setter, so do the conversion explicitly.
        if isinstance(bind, (str, sqlalchemy.engine.url.URL)):
            bind = sqlalchemy.create_engine(bind)
        # Save parameters not saved by the base class for use in __reduce_ex__
        self.bind = bind
        self._save_reflect = reflect
        self._save_quote_schema = quote_schema
        # Lazily built {schema: CaseInsentiveSet(table names)} cache
        self._table_inventory = None
        self.database_name = database_name
        self._uses_bytes_length_limits = uses_bytes_length_limits
        # {(connection_name,): Connection}
        self._connection_pool = dict()
        # {connection_name: Transaction}
        self._transactions = dict()
        self.default_connection_name = 'default'
        self.log = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")

    def __reduce_ex__(self, protocol):
        """
        Pickle support: the object is rebuilt by calling ``self.__class__(*args)``.

        The engine is not pickled; its URL is, and ``__init__`` turns that URL
        back into an engine. Open connections and transactions are not carried over.

        NOTE(review): the pickled URL presumably includes any password it holds
        -- confirm before persisting pickles.
        """
        bind_url = None if self.bind is None else self.bind.url
        return (
            # A callable object that will be called to create the initial version of the object.
            self.__class__,
            # Positional arguments for that callable, in the same order as __init__'s parameters.
            (
                bind_url,
                self._save_reflect,
                self.schema,
                self._save_quote_schema,
                self.naming_convention,
                self.info,
                self.database_name,
                self._uses_bytes_length_limits,
            ),
            # No extra state: __init__ rebuilds everything that is needed.
            None,
        )

    def _set_parent(self, parent, **kwargs):
        # Intentionally does nothing.
        # NOTE(review): presumably overrides a SQLAlchemy SchemaItem hook so this
        # metadata object never gets attached to a parent -- confirm.
        pass

    def resolve_connection_name(self, connection_name: str = None) -> str:
        """
        Map ``connection_name`` to the key actually used in the connection pool.

        ``None`` means :attr:`default_connection_name`. For sqlite every name
        maps to ``'sqlite'`` so a single connection is reused.
        """
        if connection_name is None:
            connection_name = self.default_connection_name
        # When using sqlite don't make new connections, reuse the existing one
        if self.dialect_name == 'sqlite':
            connection_name = 'sqlite'
        return connection_name

    def _connect(self) -> sqlalchemy.engine.base.Connection:
        """Open a brand-new connection from :attr:`bind`."""
        self.log.debug(f"Connecting connection {self.bind}")
        return self.bind.connect()

    def connection(
            self,
            connection_name: str = None,
            open_if_not_exist: bool = True,
            open_if_closed: bool = True,
    ) -> sqlalchemy.engine.base.Connection:
        """
        Get the pooled connection for ``connection_name``.

        Parameters
        ----------
        connection_name:
            Name of the connection; see :meth:`resolve_connection_name`.
        open_if_not_exist:
            Open and pool a new connection if none exists under this name;
            otherwise raise ``ValueError``.
        open_if_closed:
            Replace the pooled connection with a new one if it has been closed.
            When False, a closed connection is returned as-is.
        """
        connection_name = self.resolve_connection_name(connection_name)
        connection_key = (connection_name,)
        if connection_key in self._connection_pool:
            con = self._connection_pool[connection_key]
            if con.closed and open_if_closed:
                con = self._connect()
                self._connection_pool[connection_key] = con
        else:
            if open_if_not_exist:
                con = self._connect()
                self._connection_pool[connection_key] = con
            else:
                raise ValueError(f"Connection {connection_name} does not exist, and open_if_not_exist = False")
        return con

    def connect(
            self,
            connection_name: str = None,
    ) -> sqlalchemy.engine.base.Connection:
        """Get an open pooled connection, creating or reopening it as needed."""
        return self.connection(
            connection_name,
            open_if_not_exist=True,
            open_if_closed=True,
        )

    def is_connected(self, connection_name: str = None) -> bool:
        """
        Return True if a pooled connection exists under this name and is still open.

        Never opens a connection.
        """
        try:
            con = self.connection(connection_name, open_if_not_exist=False, open_if_closed=False)
            # Previously returned ``con.closed``, which inverted the answer.
            return not con.closed
        except ValueError:
            return False

    def close_connection(self, connection_name: str = None):
        """Close the pooled connection with this name, if one exists (no error otherwise)."""
        try:
            con = self.connection(connection_name, open_if_not_exist=False, open_if_closed=False)
            con.close()
        except ValueError:
            pass

    def close_connections(self, exceptions: set = None):
        """
        Close every pooled connection whose name is not in ``exceptions``.

        Closed connections stay in the pool; :meth:`connection` reopens them on demand.
        """
        if exceptions is None:
            exceptions = set()
        for connection_key, con in self._connection_pool.items():
            connection_name = connection_key[0]
            if connection_name not in exceptions:
                self.log.debug(f'Closing connection {self} {connection_name}')
                con.close()

    def dispose(self):
        """
        Close all pooled connections, then dispose of the engine's connection pool.

        This method leaves the possibility of checked-out connections remaining open,
        as it only affects connections that are idle in the pool.
        """
        self.close_connections()
        self.bind.pool.dispose()

    def session(self):
        """Return a new ORM :class:`~sqlalchemy.orm.Session` bound to :attr:`bind`."""
        return Session(bind=self.bind)

    def _begin(self, connection_name: str) -> sqlalchemy.engine.base.Transaction:
        """Begin a transaction on the named connection and record it."""
        tx = self.connection(connection_name=connection_name).begin()
        self._transactions[connection_name] = tx
        return tx

    def begin(self, connection_name: str = None) -> sqlalchemy.engine.base.Transaction:
        """
        Return the active transaction for the named connection, beginning one if needed.

        A recorded transaction is reused while it is active; otherwise a new one is begun.
        """
        connection_name = self.resolve_connection_name(connection_name)
        if connection_name not in self._transactions:
            tx = self._begin(connection_name)
        else:
            tx = self._transactions[connection_name]
            if not tx.is_active:
                tx = self._begin(connection_name)
        return tx

    def has_active_transaction(self, connection_name: str = None):
        """Return True if a transaction was recorded for this connection name and is still active."""
        connection_name = self.resolve_connection_name(connection_name)
        if connection_name not in self._transactions:
            return False
        else:
            tx = self._transactions[connection_name]
            return tx.is_active

    def commit(self, connection_name: str = None):
        """
        Commit based on a connection name rather than via a
        'sqlalchemy.engine.base.Transaction' object (which you could call .commit() on).

        Only logs (does not raise) when there is no recorded or no active transaction.

        Parameters
        ----------
        connection_name:
            Name of the connection; see :meth:`resolve_connection_name`.
        """
        connection_name = self.resolve_connection_name(connection_name)
        if connection_name not in self._transactions:
            self.log.debug(f"Commit: There is no transaction recorded for {self} {connection_name}")
        else:
            tx = self._transactions[connection_name]
            if tx.is_active:
                self.log.info(f'Commit on {self} {connection_name} connection started')
                tx.commit()
                self.log.debug(f'Commit on {self} {connection_name} connection done')
            else:
                self.log.info(f'Connection {self} {connection_name} transaction not active (commit called)')

    def rollback(self, connection_name: str = None):
        """
        Roll back the recorded transaction for the named connection.

        Raises
        ------
        RuntimeError:
            If no transaction was recorded, or the recorded one is not active.
        """
        connection_name = self.resolve_connection_name(connection_name)
        if connection_name not in self._transactions:
            raise RuntimeError(f"rollback: There is no transaction recorded for {self} {connection_name}")
        else:
            tx = self._transactions[connection_name]
            if tx.is_active:
                tx.rollback()
                self.log.info(f'Rollback on {self} {connection_name} connection done')
            else:
                raise RuntimeError(f'Connection {self} {connection_name} transaction not active (rollback called)')

    def execute(
            self,
            sql,
            *list_params,
            transaction: bool = True,
            auto_close: bool = True,
            connection_name: str = None,
            **params
    ):
        """
        Execute ``sql`` (a string is wrapped in :func:`sqlalchemy.text`) on a pooled connection.

        Parameters
        ----------
        sql:
            SQL string or SQLAlchemy executable.
        list_params, params:
            Passed through to ``Connection.execute``.
        transaction:
            When True, or when the connection has no transaction in progress,
            the statement runs inside ``with connection.begin():`` so it is
            committed on success. Otherwise it runs inside the existing transaction.
        auto_close:
            Close the (pooled) connection afterwards.
            NOTE(review): the connection is closed before the result is returned,
            so un-buffered result rows may not be fetchable -- confirm.
        connection_name:
            Name of the connection; see :meth:`resolve_connection_name`.
        """
        connection = None
        try:
            connection = self.connect(
                connection_name=connection_name,
            )
            if isinstance(sql, str):
                sql = sqlalchemy.text(sql)
            if transaction or not connection.in_transaction():
                # Equivalent to Autocommit
                with connection.begin():
                    result = connection.execute(sql, *list_params, **params)
            else:
                result = connection.execute(sql, *list_params, **params)
            return result
        finally:
            if auto_close:
                if connection is not None:
                    connection.close()

    @staticmethod
    def _build_exec_statement(procedure_name, args):
        """
        Build an ``EXEC`` statement with qmark placeholders for ``args``.

        Returns ``(sql, params)``. Strings are stripped of surrounding whitespace.
        A string starting with ``@`` is treated as ``@name=value`` and rendered
        as ``@name=?``; every other argument (strings and non-strings alike)
        becomes a positional ``?``.
        """
        placeholders = []
        params = []
        for arg in args:
            if isinstance(arg, str):
                arg = arg.strip()
            if isinstance(arg, str) and arg.startswith('@'):
                # Handle keyword named parameters; split on the first '=' only
                # so values containing '=' survive.
                name, sep, value = arg.partition('=')
                if not sep:
                    raise ValueError(f"Named procedure argument {arg!r} has no '=value' part")
                value = value.strip()
                # The value may still carry its SQL quoting: remove the opening
                # quote and, when present, the matching closing quote.
                if value.startswith("'"):
                    value = value[1:]
                    if value.endswith("'"):
                        value = value[:-1]
                placeholders.append(f'{name.strip()}=?')
            else:
                value = arg
                placeholders.append('?')
            params.append(value)
        return f"EXEC {procedure_name} " + ', '.join(placeholders), params

    def execute_procedure(
            self,
            procedure_name,
            *args,
            return_results=False,
            dpapi_connection=None
    ):
        """
        Execute a stored procedure

        Parameters
        ----------
        procedure_name: str
            The procedure to run.
        args:
            The arguments to pass. If the DBAPI cursor has ``callproc`` they are
            passed to it unchanged. Otherwise an ``EXEC`` statement is built:
            a string ``"@name=value"`` becomes a named parameter, anything else
            a positional ``?`` parameter.
        return_results:
            Needs to be a keyword param.
            Should we try and get result rows from the procedure.
        dpapi_connection:
            A raw dpapi connection to use. Optional.
            If not given, one is taken from ``self.bind.raw_connection()`` and
            closed before returning. The connection is committed on success either way.

        Returns
        -------
        A list of result rows when ``return_results`` is True, otherwise None.

        Raises
        ------
        sqlalchemy.exc.DBAPIError:
            API error
        sqlalchemy.exc.DatabaseError:
            Proxy for database error
        ValueError:
            An ``"@name"`` argument has no ``=value`` part.

        NOTE(review): errors raised by the raw DBAPI cursor are presumably the
        driver's own exception types rather than sqlalchemy.exc wrappers -- confirm.
        """
        log = logging.getLogger(__name__)
        log.debug(f"Calling procedure {procedure_name} {args}")
        if dpapi_connection is None:
            dpapi_connection = self.bind.raw_connection()
            close_connection = True
        else:
            close_connection = False
        results = None
        try:
            cursor = dpapi_connection.cursor()
            try:
                if hasattr(cursor, 'callproc'):
                    cursor.callproc(procedure_name, args)
                else:
                    # Stopped using CALL because of issues like those mentioned on
                    # https://stackoverflow.com/a/34179375
                    sql, params = self._build_exec_statement(procedure_name, args)
                    cursor.execute(sql, params)
                if return_results:
                    results = list(cursor.fetchall())
            finally:
                # Close the cursor even when the call fails
                cursor.close()
            dpapi_connection.commit()
        finally:
            if close_connection:
                dpapi_connection.close()
        return results

    def execute_direct(
            self,
            sql,
            return_results=False
    ):
        """
        Run ``sql`` directly on a new raw DBAPI connection, commit, and close it.

        Returns a list of result rows when ``return_results`` is True, otherwise None.
        """
        log = logging.getLogger(__name__)
        log.debug(sql)
        dpapi_connection = self.bind.raw_connection()
        results = None
        try:
            cursor = dpapi_connection.cursor()
            try:
                cursor.execute(sql)
                if return_results:
                    results = list(cursor.fetchall())
            finally:
                # Close the cursor even when execution fails
                cursor.close()
            dpapi_connection.commit()
        finally:
            dpapi_connection.close()
        return results

    def table_inventory(self, schema=None, force_reload=False):
        """
        Return the (case-insensitive) set of table names in ``schema``.

        Results are cached per schema; ``force_reload`` re-reads them via reflection.
        """
        if self._table_inventory is None:
            self._table_inventory = dict()
        if schema not in self._table_inventory or force_reload:
            try:
                from sqlalchemy import inspect
            except ImportError:
                # Fallback for SQLAlchemy versions without sqlalchemy.inspect
                inspect = Inspector.from_engine
            inspector = inspect(self.bind)
            self._table_inventory[schema] = CaseInsentiveSet(inspector.get_table_names(schema=schema))
        return self._table_inventory[schema]

    @staticmethod
    def qualified_name(schema, table):
        """Return ``schema.table``, or just ``table`` when ``schema`` is None."""
        if schema is not None:
            return schema + '.' + table
        else:
            return table

    def rename_table(self, schema, table_name, new_table_name):
        """Rename a table (``sp_rename`` on mssql, ``ALTER TABLE ... RENAME TO`` elsewhere)."""
        if self.dialect_name == 'mssql':
            self.execute_procedure(
                'sp_rename',
                self.qualified_name(schema, table_name),
                new_table_name
            )
        else:
            sql = f"alter table {self.qualified_name(schema, table_name)} rename to {new_table_name}"
            self.log.debug(sql)
            self.execute(sql)

    def drop_table_if_exists(
            self,
            table_name,
            schema=None,
            connection_name: str = None,
            transaction: bool = False,
            auto_close: bool = False,
    ):
        """
        Drop a table if it exists, using dialect-appropriate SQL.

        Parameters
        ----------
        table_name:
            Table name. If ``schema`` is None and the name contains dots, the
            part after the last dot is the table and the rest is the schema.
        schema:
            Optional schema name.
        connection_name, transaction, auto_close:
            Passed to :meth:`execute`.
        """
        if schema is None and '.' in table_name:
            # Split on the last dot so multi-part names (e.g. db.schema.table)
            # do not fail to unpack.
            schema, table_name = table_name.rsplit('.', 1)
        qualified_table = self.qualified_name(schema, table_name)

        # SQL Server 2016+ can use IF EXISTS but rather than checking version use compatible mode
        if self.dialect_name == 'mssql':
            if table_name[0] == '#':
                # Temp table
                sql = textwrap.dedent(f"""\
                    IF OBJECT_ID('tempdb.dbo.{table_name}', 'U') IS NOT NULL
                        DROP TABLE {qualified_table};
                    """)
            else:
                sql = textwrap.dedent(f"""\
                    IF OBJECT_ID('{qualified_table}', 'U') IS NOT NULL
                        DROP TABLE {qualified_table};
                    """)
        elif self.dialect_name == 'oracle':
            # SQLCODE -942 is ORA-00942 (table or view does not exist); any other
            # error is re-raised. The schema is now included (it was previously ignored).
            sql = textwrap.dedent(f"""\
                BEGIN
                   EXECUTE IMMEDIATE 'DROP TABLE {qualified_table}';
                EXCEPTION
                   WHEN OTHERS THEN
                      IF SQLCODE != -942 THEN
                         RAISE;
                      END IF;
                END;
                """)
        else:
            sql = f"drop table IF EXISTS {qualified_table}"
        self.log.debug(sql)
        self.execute(
            sql,
            transaction=transaction,
            auto_close=auto_close,
            connection_name=connection_name,
        )

    @property
    def dialect(self):
        """The SQLAlchemy dialect of :attr:`bind`."""
        return self.bind.dialect

    @property
    def dialect_name(self):
        """The dialect name of :attr:`bind` (e.g. ``'mssql'``, ``'oracle'``, ``'sqlite'``)."""
        return self.bind.dialect.name

    @property
    def uses_bytes_length_limits(self):
        """
        True if string column length limits are counted in bytes rather than characters.

        Uses the value given to ``__init__`` if any; otherwise True for the
        redshift and oracle dialects and False for everything else (cached after first use).
        """
        if self._uses_bytes_length_limits is None:
            # Note: Oracle can use either VARCHAR2(10 CHAR) or VARCHAR2(10 BYTE)
            # However, if not specified (and NLS_LENGTH_SEMANTICS is default), it's char so we assume that.
            if self.dialect_name in {
                'redshift',
                'oracle'
            }:
                self._uses_bytes_length_limits = True
            else:
                self._uses_bytes_length_limits = False
        return self._uses_bytes_length_limits