Source code for bi_etl.components.table

"""
Created on Sep 17, 2014

@author: Derek Wood
"""
# https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import codecs
import contextlib
import dataclasses
import math
import sys
import textwrap
import traceback
import warnings
from datetime import datetime, date, time, timedelta
from decimal import Decimal
from decimal import InvalidOperation
from typing import *
from typing import Iterable, Callable, List, Union

import sqlalchemy
from gevent import spawn, sleep
from gevent.queue import Queue
from sqlalchemy import CHAR
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.sql.type_api import TypeEngine

import bi_etl
from bi_etl.bulk_loaders.bulk_loader import BulkLoader
from bi_etl.components.etlcomponent import ETLComponent
from bi_etl.components.get_next_key.local_table_memory import LocalTableMemory
from bi_etl.components.readonlytable import ReadOnlyTable
from bi_etl.components.row.column_difference import ColumnDifference
from bi_etl.components.row.row import Row
from bi_etl.components.row.row_status import RowStatus
from bi_etl.conversions import nvl, int2base
from bi_etl.conversions import replace_tilda
from bi_etl.conversions import round_datetime_ms
from bi_etl.conversions import str2date
from bi_etl.conversions import str2datetime
from bi_etl.conversions import str2decimal
from bi_etl.conversions import str2float
from bi_etl.conversions import str2int
from bi_etl.conversions import str2time
from bi_etl.database import DatabaseMetadata
from bi_etl.exceptions import NoResultFound
from bi_etl.lookups.lookup import Lookup
from bi_etl.scheduler.task import ETLTask
from bi_etl.statement_queue import StatementQueue
from bi_etl.statistics import Statistics
from bi_etl.timer import Timer
from bi_etl.utility import dict_to_str
from bi_etl.utility import get_integer_places


@dataclasses.dataclass
class PendingUpdate:
    update_where_keys: Tuple[str]
    update_where_values: Iterable[Any]
    update_row_new_values: Row
# noinspection SqlDialectInspection
class Table(ReadOnlyTable):
    """
    A class for accessing and updating a table.

    Parameters
    ----------
    task : ETLTask
        The instance to register in (if not None)
    database : bi_etl.scheduler.task.Database
        The database to find the table/view in.
    table_name : str
        The name of the table/view.
    exclude_columns
        Optional. A list of columns to exclude from the table/view.
        These columns will not be included in SELECT, INSERT, or UPDATE statements.

    Attributes
    ----------
    auto_generate_key: boolean
        Should the primary key be automatically generated by the insert/upsert process?
        If True, the process will get the current maximum value and then increment it with each insert.
    batch_size: int
        How many rows should be inserted / updated / deleted in a single batch.
        Default = 5000. Assigning None will use the default.
    delete_flag : str
        The name of the delete_flag column, if any.
        (inherited from ReadOnlyTable)
    delete_flag_yes : str, optional
        The value of delete_flag for deleted rows.
        (inherited from ReadOnlyTable)
    delete_flag_no : str, optional
        The value of delete_flag for *not* deleted rows.
        (inherited from ReadOnlyTable)
    default_date_format: str
        The date parsing format to use for str -> date conversions.
        If more than one date format exists in the source, then explicit conversions will be required.
        Default = '%Y-%m-%d'
    default_date_time_format: Union[str, Iterable[str]]
        The date + time parsing format(s) to use for str -> datetime conversions.
        If more than one date format exists in the source, then explicit conversions will be required.
        Default = ('%Y-%m-%d %H:%M:%S', '%Y-%m-%d', '%m/%d/%Y %H:%M:%S', '%m/%d/%Y')
    default_time_format: str
        The time parsing format to use for str -> time conversions.
        If more than one time format exists in the source, then explicit conversions will be required.
        Default = '%H:%M:%S'
    force_ascii: boolean
        Should text values be forced into the ascii character set before passing to the database?
        Default = False
    last_update_date: str
        Name of the column which we should update when table updates are made.
        Default = None
    log_first_row : boolean
        Should we log progress on the first row read? *Only applies if used as a source.*
        (inherited from ETLComponent)
    max_rows : int, optional
        The maximum number of rows to read. *Only applies if Table is used as a source.*
        (inherited from ETLComponent)
    primary_key
        The name of the primary key column(s). Only impacts trace messages. Default = None.
        If not passed in, will use the database value, if any.
        (inherited from ETLComponent)
    progress_frequency : int, optional
        How often (in seconds) to output progress messages.
        (inherited from ETLComponent)
    progress_message : str, optional
        The progress message to print. Default is ``"{logical_name} row # {row_number}"``.
        Note ``logical_name`` and ``row_number`` substitutions are applied via :func:`format`.
        (inherited from ETLComponent)
    special_values_descriptive_columns
        A list of columns that get longer descriptive text in
        :meth:`get_missing_row`, :meth:`get_invalid_row`, :meth:`get_not_applicable_row`,
        :meth:`get_various_row`
        (inherited from ReadOnlyTable)
    track_source_rows: boolean
        Should the :meth:`upsert` method keep a set container of source row keys that it has processed?
        That set would then be used by :meth:`update_not_processed`,
        :meth:`logically_delete_not_processed`, and :meth:`delete_not_processed`.
    """
    DEFAULT_BATCH_SIZE = 5000

    # Replacement for float Not a Number (NaN) values
    NAN_REPLACEMENT_VALUE = None

    from enum import IntEnum, unique
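    # Illustrative usage sketch (not part of the original source): a minimal way to load a
    # target table, assuming `task` is an ETLTask (or None), `db` is a DatabaseMetadata for
    # the target database, and `source` yields Row objects (all three names are placeholders).
    #
    #     target = Table(task, db, 'd_example')
    #     target.auto_generate_key = True     # allocate the surrogate key on insert
    #     target.track_source_rows = True     # required for the *_not_processed methods
    #     target.batch_size = 1000            # override the 5000 row default
    #     for row in source:
    #         target.upsert(row)
    #     target.commit()
    #     target.close()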
    @unique
    class InsertMethod(IntEnum):
        execute_many = 1
        insert_values_list = 2
        bulk_load = 3
    @unique
    class UpdateMethod(IntEnum):
        execute_many = 1
        bulk_load = 3
    @unique
    class DeleteMethod(IntEnum):
        execute_many = 1
        bulk_load = 3
    def __init__(
            self,
            task: Optional[ETLTask],
            database: DatabaseMetadata,
            table_name: str,
            table_name_case_sensitive: bool = True,
            schema: Optional[str] = None,
            exclude_columns: Optional[set] = None,
            **kwargs
    ):
        # Don't pass kwargs up. They should be set here at the end
        super(Table, self).__init__(
            task=task,
            database=database,
            table_name=table_name,
            table_name_case_sensitive=table_name_case_sensitive,
            schema=schema,
            exclude_columns=exclude_columns,
        )
        self.support_multiprocessing = False
        self._special_row_header = None
        self._insert_method = Table.InsertMethod.execute_many
        self._update_method = Table.UpdateMethod.execute_many
        self._delete_method = Table.DeleteMethod.execute_many
        self.bulk_loader = None
        self._bulk_rollback = False
        self._bulk_load_performed = False
        self._bulk_iter_sentinel = None
        self._bulk_iter_queue = None
        self._bulk_iter_worker = None
        self._bulk_iter_max_q_length = 0
        self._bulk_defaulted_columns = set()
        self.track_update_columns = True
        self.track_source_rows = False
        self.auto_generate_key = False
        self.upsert_called = False
        self.last_update_date = None
        self.use_utc_times = False
        self.default_date_format = '%Y-%m-%d'
        self.default_date_time_format = (
            '%Y-%m-%d %H:%M:%S',
            '%Y-%m-%d',
            '%m/%d/%Y %H:%M:%S',
            '%m/%d/%Y',
        )
        self.default_time_format = '%H:%M:%S'
        self.default_long_text = 'Not Available'
        self.default_medium_text = 'N/A'
        self.default_char1_text = '?'
        self.force_ascii = False
        codecs.register_error('replace_tilda', replace_tilda)
        self.__batch_size = self.DEFAULT_BATCH_SIZE
        self.__transaction_pool = dict()
        self.skip_coercion_on = {}
        self._logical_delete_update = None
        # Safe type mode is slower, but gives better error messages than the database,
        # which would likely give a not-so-helpful message or silently truncate a value.
        self.safe_type_mode = True
        # Init table "memory"
        self.max_keys = dict()
        self.max_locally_allocated_keys = dict()
        self._table_key_memory = LocalTableMemory(self)
        self._max_keys_lock = contextlib.nullcontext()
        # Make a self example source row
        self._example_row = self.Row()
        if self.columns is not None:
            for c in self.columns:
                self._example_row[c] = None
        self.source_keys_processed = set()
        self._bind_name_map = None
        self.insert_hint = None
        # A list of any pending rows to be inserted
        self.pending_insert_stats = None
        self.pending_insert_rows = list()
        # A StatementQueue of any pending delete statements; the queue has the
        # statement itself (based on the keys) and a list of pending values
        self.pending_delete_stats = None
        self.pending_delete_statements = StatementQueue()
        # A list of any pending rows to apply as updates
        self.pending_update_stats = None
        self.pending_update_rows: Dict[tuple, PendingUpdate] = dict()
        self._coerce_methods_built = False
        # Should be called after all self attributes are created
        self.set_kwattrs(**kwargs)
    def close(self, error: bool = False):
        if not self.is_closed:
            if self.in_bulk_mode:
                if self._bulk_load_performed:
                    self.log.debug(
                        f"{self}.close() NOT causing bulk load since a bulk load has already been performed."
                    )
                elif error:
                    self.log.info(f"{self}.close() NOT causing bulk load due to error.")
                else:
                    self.bulk_load_from_cache()
            else:
                self._insert_pending_batch()
            # Call self.commit() if any connection has an active transaction,
            # but leave it up to commit() to determine the order
            any_active_transactions = False
            for connection_name in self._connections_used:
                if self.database.has_active_transaction(connection_name):
                    any_active_transactions = True
            if any_active_transactions:
                self.commit()
            super(Table, self).close(error=error)
            self.clear_cache()
    def __iter__(self) -> Iterable[Row]:
        # Note: yield_per will break if the transaction is committed while we are looping
        for row in self.where(None):
            yield row

    @property
    def batch_size(self):
        return self.__batch_size

    @batch_size.setter
    def batch_size(self, batch_size):
        if batch_size is not None:
            if not self.in_bulk_mode:
                if batch_size > 0:
                    self.__batch_size = batch_size
                else:
                    self.__batch_size = 1

    @property
    def table_key_memory(self):
        return self._table_key_memory

    @table_key_memory.setter
    def table_key_memory(self, table_key_memory):
        self._table_key_memory = table_key_memory

    @property
    def in_bulk_mode(self):
        return self._insert_method == Table.InsertMethod.bulk_load
    def set_bulk_loader(
            self,
            bulk_loader: BulkLoader
    ):
        self.log.info(f'Changing {self} to bulk load method')
        self.__batch_size = sys.maxsize
        self._insert_method = Table.InsertMethod.bulk_load
        self._update_method = Table.UpdateMethod.bulk_load
        self._delete_method = Table.DeleteMethod.bulk_load
        self.bulk_loader = bulk_loader
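    # Illustrative sketch (not part of the original source): switching a Table into bulk load
    # mode. `SomeBulkLoader` is a hypothetical BulkLoader subclass; any configured BulkLoader
    # instance is passed the same way.
    #
    #     loader = SomeBulkLoader(...)        # hypothetical; construct per your loader's API
    #     target.set_bulk_loader(loader)
    #     # From this point inserts/updates/deletes are routed through the bulk loader rather
    #     # than individual SQL statements; close() triggers bulk_load_from_cache() unless an
    #     # error occurred or a bulk load was already performed.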
    def cache_row(
            self,
            row: Row,
            allow_update: bool = False,
            allow_insert: bool = True,
    ):
        if self.in_bulk_mode:
            self._bulk_load_performed = False
            if self.bulk_loader.needs_all_columns:
                if row.column_set != self.column_names_set:
                    # Add missing columns, setting each to a default value
                    for column_name in self.column_names_set - row.column_set:
                        column = self.get_column(column_name)
                        default = column.default
                        if not column.nullable and default is None:
                            if column.type.python_type == str:
                                col_len = column.type.length
                                if col_len is None:
                                    # Note: length is None for CLOB-like columns
                                    col_len = 4000
                                elif col_len > 4 and self.database.uses_bytes_length_limits:
                                    col_len = int(col_len / 4)
                                if col_len >= len(self.default_long_text):
                                    default = self.default_long_text
                                elif col_len >= len(self.default_medium_text):
                                    default = self.default_medium_text
                                else:
                                    default = self.default_char1_text
                        row.set_keeping_parent(column_name, default)
                        if column_name not in self._bulk_defaulted_columns:
                            self.log.warning(f'defaulted column {column_name} to {default}')
                            self._bulk_defaulted_columns.add(column_name)
        super().cache_row(
            row,
            allow_update=allow_update,
            allow_insert=allow_insert
        )
    @property
    def insert_method(self):
        return self._insert_method

    @insert_method.setter
    def insert_method(self, value):
        if value == Table.InsertMethod.bulk_load:
            raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
        self._insert_method = value

    @property
    def update_method(self):
        return self._update_method

    @update_method.setter
    def update_method(self, value):
        if value == Table.UpdateMethod.bulk_load:
            raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
        self._update_method = value

    @property
    def delete_method(self):
        return self._delete_method

    @delete_method.setter
    def delete_method(self, value):
        if value == Table.DeleteMethod.bulk_load:
            raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
        self._delete_method = value
    def autogenerate_sequence(
            self,
            row: Row,
            seq_column: str,
            force_override: bool = True,
    ):
        # Make sure we have a column object
        seq_column_obj = self.get_column(seq_column)
        # If key value is not already set, or we are supposed to force override
        if row.get(seq_column_obj.name) is None or force_override:
            next_key = self.table_key_memory.get_next_key(seq_column)
            row.set_keeping_parent(seq_column_obj.name, next_key)
            return next_key
    def autogenerate_key(
            self,
            row: Row,
            force_override: bool = True,
    ):
        if self.auto_generate_key:
            if self.primary_key is None:
                raise ValueError("No primary key set")
            pk_list = list(self.primary_key)
            if len(pk_list) > 1:
                raise ValueError(
                    f"Can't auto generate a compound key with table {self} pk={self.primary_key}"
                )
            key = pk_list[0]
            return self.autogenerate_sequence(row, seq_column=key, force_override=force_override)
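    # Illustrative sketch (not part of the original source): with auto_generate_key enabled and
    # a single-column primary key, the insert/upsert path calls autogenerate_key for new rows.
    # It can also be called directly on a Row built for this table:
    #
    #     target.auto_generate_key = True
    #     new_key = target.autogenerate_key(new_row, force_override=False)
    #     # force_override=False keeps any key value already present in new_row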
    # noinspection PyUnresolvedReferences
    def _trace_data_type(
            self,
            target_name: str,
            t_type: TypeEngine,
            target_column_value: object,
    ):
        try:
            self.log.debug(f"{target_name} t_type={t_type}")
            self.log.debug(f"{target_name} t_type.precision={t_type.precision}")
            self.log.debug(f"{target_name} target_column_value={target_column_value}")
            self.log.debug(
                f"{target_name} get_integer_places(target_column_value)={get_integer_places(target_column_value)}"
            )
            self.log.debug(
                f"{target_name} (t_type.precision - t_type.scale)={(nvl(t_type.precision, 0) - nvl(t_type.scale, 0))}"
            )
        except AttributeError as e:
            self.log.error(traceback.format_exc())
            self.log.debug(repr(e))

    def _generate_null_check(
            self,
            target_column_object: ColumnElement
    ) -> str:
        target_name = target_column_object.name
        code = ''
        if not target_column_object.nullable:
            if not (self.auto_generate_key and target_name in self.primary_key):
                # Check for nulls. Not as an ELSE because the conversion logic might have made the value null
                code = textwrap.dedent(
                    f"""\
                    # base indent
                        if target_column_value is None:
                            msg = "{self}.{target_name} is not nullable and cannot accept value '{{val}}'".format(
                                val=target_column_value,
                            )
                            raise ValueError(msg)
                    """
                )
        return code

    def _get_coerce_method_name_by_str(self, target_column_name: str) -> str:
        if not self._coerce_methods_built:
            self._build_coerce_methods()
        return f"_coerce_{target_column_name}"

    def _get_coerce_method_name_by_object(
            self,
            target_column_object: Union[str, 'sqlalchemy.sql.expression.ColumnElement']
    ) -> str:
        target_column_name = self.get_column_name(target_column_object)
        return self._get_coerce_method_name_by_str(target_column_name)
    def get_coerce_method(
            self,
            target_column_object: Union[str, 'sqlalchemy.sql.expression.ColumnElement']
    ) -> Callable:
        if not self._coerce_methods_built:
            self._build_coerce_methods()
        method_name = self._get_coerce_method_name_by_object(target_column_object)
        try:
            return getattr(self, method_name)
        except AttributeError:
            raise AttributeError(
                f'{self} does not have a coerce method for {target_column_object} '
                f'- check that vs columns list {self.column_names}'
            )
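    # Illustrative sketch (not part of the original source): the dynamically built per-column
    # coerce methods validate and convert a single value; they can be fetched and applied
    # directly. 'amount' is a hypothetical numeric column here.
    #
    #     coerce_amount = target.get_coerce_method('amount')
    #     clean_value = coerce_amount('1,234.56')   # str2decimal-style parsing handles commas/signs
    #     # Raises ValueError if the value overflows the column's precision or is null
    #     # for a NOT NULL column.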
def _make_generic_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_str_coerce( self, target_column_object: ColumnElement, ): target_name = target_column_object.name t_type = target_column_object.type name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += textwrap.dedent( """\ # base indent if isinstance(target_column_value, str): """ ) if self.force_ascii: # Passing ascii bytes to cx_Oracle is not working. # We need to pass a str value. # So we'll use encode with 'replace' to force ascii compatibility code += textwrap.dedent( """\ # base indent target_column_value = \ target_column_value.encode('ascii', 'replace_tilda').decode('ascii') """ ) else: code += textwrap.dedent( """\ # base indent pass """ ) code += textwrap.dedent( """\ # base indent elif isinstance(target_column_value, bytes): target_column_value = target_column_value.decode('ascii') elif target_column_value is None: return None else: target_column_value = str(target_column_value) """ ) # Note: t_type.length is None for CLOB fields if t_type.length is not None: try: if t_type.length > 0: if self.database.uses_bytes_length_limits: # Encode the str as uft-8 to get the byte length code += textwrap.dedent( f"""\ # base indent value_len = len(target_column_value.encode('utf-8')) if value_len > {t_type.length}: msg = ("{self}.{target_name} has type {str(t_type).replace('"', "'")} " f"which cannot accept value '{{target_column_value}}' because " f"byte length of {{len(target_column_value.encode('utf-8'))}} > {t_type.length} limit (_make_str_coerce)" f"Note: char length is {{len(target_column_value)}}" ) raise ValueError(msg) """ ) else: code += textwrap.dedent( f"""\ # base indent value_len = len(target_column_value) if value_len > {t_type.length}: msg = ("{self}.{target_name} has type {str(t_type).replace('"', "'")} " f"which cannot accept value '{{target_column_value}}' because " f"char length of {{len(target_column_value)}} > {t_type.length} limit (_make_str_coerce)" ) raise ValueError(msg) """ ) except TypeError: # t_type.length is not a comparable type pass if isinstance(t_type, CHAR): code += textwrap.dedent( f"""\ # base indent if value_len < {t_type.length}: target_column_value += ' ' * ({t_type.length} - len(target_column_value)) """ ) code += str(self._generate_null_check(target_column_object)) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_bytes_coerce( self, target_column_object: ColumnElement, ): target_name = target_column_object.name t_type = target_column_object.type name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += textwrap.dedent( """\ # base indent if isinstance(target_column_value, bytes): pass elif isinstance(target_column_value, str): target_column_value = target_column_value.encode('utf-8') elif target_column_value is 
None: return None else: target_column_value = str(target_column_value).encode('utf-8') """ ) # t_type.length is None for BLOB, LargeBinary fields. # This really might not be required since all # discovered types with python_type == bytes: # have no length if t_type.length is not None: try: if t_type.length > 0: code += textwrap.dedent( """\ # base indent if len(target_column_value) > {len}: msg = "{table}.{column} has type {type} which cannot accept value '{{val}}' because length {{val_len}} > {len} limit (_make_bytes_coerce)" msg = msg.format( val=target_column_value, val_len=len(target_column_value), ) raise ValueError(msg) """ ).format( len=t_type.length, table=self, column=target_name, type=str(t_type).replace('"', "'"), ) except TypeError: # t_type.length is not a comparable type pass code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_int_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, int): pass elif target_column_value is None: return None elif isinstance(target_column_value, str): target_column_value = str2int(target_column_value) elif isinstance(target_column_value, float): target_column_value = int(target_column_value) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_int_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_float_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. 
# The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, float): if math.isnan(target_column_value): target_column_value = self.NAN_REPLACEMENT_VALUE elif target_column_value is None: return None elif isinstance(target_column_value, str): target_column_value = str2float(target_column_value) elif isinstance(target_column_value, int): target_column_value = float(target_column_value) elif isinstance(target_column_value, Decimal): if math.isnan(target_column_value): target_column_value = self.NAN_REPLACEMENT_VALUE else: target_column_value = float(target_column_value) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_float_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_decimal_coerce( self, target_column_object: ColumnElement, ): target_name = target_column_object.name t_type = target_column_object.type name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += textwrap.dedent( """\ # base indent if isinstance(target_column_value, Decimal): pass elif isinstance(target_column_value, float): pass elif target_column_value is None: return None """ ) # If for performance reasons you don't want this conversion... # DON'T send in a string! # str2decimal takes 765 ns vs 312 ns for Decimal() but handles commas and signs. # The thought is that ETL jobs that need the performance and can # guarantee no commas can explicitly use float or Decimal code += textwrap.dedent( """\ # base indent elif isinstance(target_column_value, str): try: target_column_value = str2decimal(target_column_value) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_decimal_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) # t_type.length is None for BLOB, LargeBinary fields. 
# This really might not be required since all # discovered types with python_type == bytes: # have no length if t_type.precision is not None: scale = nvl(t_type.scale, 0) integer_digits_allowed = t_type.precision - scale code += textwrap.dedent( """\ # base indent if target_column_value is not None: digits = get_integer_places(target_column_value) if digits > {integer_digits_allowed}: msg = "{table}.{column} can't accept '{{val}}' since it has {{digits}} integer digits (_make_decimal_coerce)"\ "which is > {integer_digits_allowed} by (prec {precision} - scale {scale}) limit".format( val=target_column_value, digits=digits, ) raise ValueError(msg) """ ).format( table=self, column=target_name, integer_digits_allowed=integer_digits_allowed, precision=t_type.precision, scale=scale ) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_date_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. # The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float code += textwrap.dedent( """\ # base indent try: # Note datetime check must be 1st because datetime tests as an instance of date if isinstance(target_column_value, datetime): target_column_value = date(target_column_value.year, target_column_value.month, target_column_value.day) elif isinstance(target_column_value, date): pass elif target_column_value is None: return None elif isinstance(target_column_value, str): target_column_value = str2date(target_column_value, dt_format=self.default_date_format) else: target_column_value = str2date(str(target_column_value), dt_format=self.default_date_format) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_date_coerce) {{fmt}}".format( val=target_column_value, e=e, fmt=self.default_date_format, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_datetime_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, datetime): pass elif isinstance(target_column_value, timedelta): pass elif target_column_value is None: return None elif isinstance(target_column_value, date): target_column_value = datetime.combine(target_column_value, time.min) elif isinstance(target_column_value, str): target_column_value = str2datetime(target_column_value, dt_format=self.default_date_time_format) else: target_column_value = str2datetime(str(target_column_value), dt_format=self.default_date_time_format) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_datetime_coerce) {{fmt}}".format( 
val=target_column_value, e=e, fmt=self.default_date_time_format, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) if self.table.bind.dialect.dialect_description == 'mssql+pyodbc': # fast_executemany Currently causes this error on datetime update (dimension load) # [Microsoft][ODBC Driver 17 for SQL Server]Datetime field overflow. Fractional second precision exceeds the scale specified in the parameter binding. (0) # Also see https://github.com/sqlalchemy/sqlalchemy/issues/4418 # All because SQL Server DATETIME values are limited to 3 digits # Check for datetime2 and don't do this! if str(target_column_object.type) == 'DATETIME': self.log.warning(f"Rounding microseconds on {target_column_object}") code += textwrap.dedent( """ # base indent target_column_value = round_datetime_ms(target_column_value, 3) """ ) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: code = compile(code, filename='_make_datetime_coerce', mode='exec') exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_time_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. # The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, time): pass elif target_column_value is None: return None elif isinstance(target_column_value, datetime): target_column_value = time(target_column_value.hour, target_column_value.minute, target_column_value.second, target_column_value.microsecond, target_column_value.tzinfo, ) elif isinstance(target_column_value, str): target_column_value = str2time(target_column_value, dt_format=self.default_time_format) else: target_column_value = str2time(str(target_column_value), dt_format=self.default_time_format) except ValueError as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_time_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_timedelta_coerce( self, target_column_object: ColumnElement, ): name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. 
# The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, timedelta): pass elif target_column_value is None: return None elif isinstance(target_column_value, float): if math.isnan(target_column_value): target_column_value = self.NAN_REPLACEMENT_VALUE target_column_value = timedelta(seconds=target_column_value) elif isinstance(target_column_value, str): target_column_value = str2float(target_column_value) target_column_value = timedelta(seconds=target_column_value) elif isinstance(target_column_value, int): target_column_value = timedelta(seconds=target_column_value) else: target_column_value = timedelta(seconds=target_column_value) except (TypeError, ValueError) as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_timedelta_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_column_object.name) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _make_bool_coerce(self, target_column_object: 'sqlalchemy.sql.expression.ColumnElement'): target_name = target_column_object.name name = self._get_coerce_method_name_by_object(target_column_object) code = f"def {name}(self, target_column_value):" # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. # The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float code += textwrap.dedent( """\ # base indent try: if isinstance(target_column_value, bool): pass elif target_column_value is None: return None elif isinstance(target_column_value, str): target_column_value = target_column_value.lower() if target_column_value in ['true', 'yes', 'y']: target_column_value = True elif target_column_value in ['false', 'no', 'n']: target_column_value = True else: type_error = True msg = "{table}.{column} unexpected value {{val}} (expected true/false, yes/no, y/n)".format( val=target_column_value, ) raise ValueError(msg) else: target_column_value = bool(target_column_value) except (TypeError, ValueError) as e: msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_bool_coerce)".format( val=target_column_value, e=e, ) raise ValueError(msg) """ ).format(table=self, column=target_name, ) code += self._generate_null_check(target_column_object) code += textwrap.dedent( """ # base indent return target_column_value """ ) try: code = compile(code, filename='_make_bool_coerce', mode='exec') exec(code) except SyntaxError as e: self.log.exception(f"{e} from code\n{code}") # Add the new function as a method in this class exec(f"self.{name} = {name}.__get__(self)") def _build_coerce_methods(self): self._coerce_methods_built = True for column in self.columns: if column.name in self.skip_coercion_on: self._make_generic_coerce(column) else: t_type = column.type try: if t_type.python_type == str: self._make_str_coerce(column) elif t_type.python_type == bytes: self._make_bytes_coerce(column) elif t_type.python_type == int: self._make_int_coerce(column) elif t_type.python_type == float: self._make_float_coerce(column) elif t_type.python_type == Decimal: self._make_decimal_coerce(column) elif t_type.python_type == date: 
self._make_date_coerce(column) elif t_type.python_type == datetime: self._make_datetime_coerce(column) elif t_type.python_type == time: self._make_time_coerce(column) elif t_type.python_type == timedelta: self._make_timedelta_coerce(column) elif t_type.python_type == bool: self._make_bool_coerce(column) else: warnings.warn(f'Table.build_row has no handler for {t_type} = {type(t_type)}') self._make_generic_coerce(column) except NotImplementedError: raise ValueError(f"Column {column} has type {t_type} with no python_type")
[docs] def column_coerce_type( self, target_column_object: ColumnElement, target_column_value: object, ): """ This is the slower non-dynamic code based data type conversion routine """ target_name = target_column_object.name if target_column_value is not None: t_type = target_column_object.type type_error = False err_msg = None try: if t_type.python_type == str: if isinstance(target_column_value, str): if self.force_ascii: # Passing ascii bytes to cx_Oracle is not working. # We need to pass a str value. # So we'll use encode with 'replace' to force ascii compatibility target_column_value = \ target_column_value.encode('ascii', 'replace_tilda').decode('ascii') elif isinstance(target_column_value, bytes): target_column_value = target_column_value.decode('ascii') else: target_column_value = str(target_column_value) # Note: t_type.length is None for CLOB fields try: if t_type.length is not None and len(target_column_value) > t_type.length: type_error = True err_msg = f"length {len(target_column_value)} > {t_type.length} limit" except TypeError: # t_type.length is not a comparable type pass elif t_type.python_type == bytes: if isinstance(target_column_value, str): target_column_value = target_column_value.encode('utf-8') elif isinstance(target_column_value, bytes): pass # We don't update the value else: target_column_value = str(target_column_value).encode('utf-8') # t_type.length is None for BLOB, LargeBinary fields. # This really might not be required since all # discovered types with python_type == bytes: # have no length if t_type.length is not None: if len(target_column_value) > t_type.length: type_error = True err_msg = f"{len(target_column_value)} > {t_type.length}" elif t_type.python_type == int: if isinstance(target_column_value, str): # Note: str2int takes 590 ns vs 220 ns for int() but handles commas and signs. target_column_value = str2int(target_column_value) elif t_type.python_type == float: # noinspection PyTypeChecker if isinstance(target_column_value, str): # Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs. # The thought is that ETL jobs that need the performance and can guarantee no commas # can explicitly use float target_column_value = str2float(target_column_value) elif math.isnan(float(target_column_value)): target_column_value = self.NAN_REPLACEMENT_VALUE elif t_type.python_type == Decimal: if isinstance(target_column_value, str): # If for performance reasons you don't want this conversion... # DON'T send in a string! # str2decimal takes 765 ns vs 312 ns for Decimal() but handles commas and signs. 
# The thought is that ETL jobs that need the performance and can # guarantee no commas can explicitly use float or Decimal target_column_value = str2decimal(target_column_value) elif isinstance(target_column_value, float): if math.isnan(target_column_value): target_column_value = self.NAN_REPLACEMENT_VALUE if t_type.precision is not None: scale = nvl(t_type.scale, 0) if get_integer_places(target_column_value) > (t_type.precision - scale): type_error = True err_msg = ( f"{get_integer_places(target_column_value)} " f"digits > {(t_type.precision - scale)} = " f"(prec {t_type.precision} - scale {t_type.scale}) limit" ) elif t_type.python_type == date: # If we already have a datetime, make it a date if isinstance(target_column_value, datetime): target_column_value = date( target_column_value.year, target_column_value.month, target_column_value.day ) # If we already have a date elif isinstance(target_column_value, date): pass else: target_column_value = str2date( str(target_column_value), dt_format=self.default_date_format ) elif t_type.python_type == datetime: # If we already have a date or datetime value if isinstance(target_column_value, datetime): if 'mssql' in self.table.bind.dialect.dialect_description: # fast_executemany Currently causes this error on datetime update (dimension load) # [Microsoft][ODBC Driver 17 for SQL Server]Datetime field overflow. Fractional second precision exceeds the scale specified in the parameter binding. (0) # Also see https://github.com/sqlalchemy/sqlalchemy/issues/4418 # All because SQL Server DATETIME values are limited to 3 digits # Check for datetime2 and don't do this! if str(target_column_object.type) == 'DATETIME': target_column_value = round_datetime_ms(target_column_value, 3) elif isinstance(target_column_value, str): target_column_value = str2datetime( target_column_value, dt_format=self.default_date_time_format ) else: target_column_value = str2datetime( str(target_column_value), dt_format=self.default_date_time_format ) elif t_type.python_type == time: # If we already have a datetime, make it a time if isinstance(target_column_value, datetime): target_column_value = time( target_column_value.hour, target_column_value.minute, target_column_value.second, target_column_value.microsecond, target_column_value.tzinfo, ) # If we already have a date or time value elif isinstance(target_column_value, time): pass else: target_column_value = str2time( str(target_column_value), dt_format=self.default_time_format ) elif t_type.python_type == timedelta: # If we already have an interval value if isinstance(target_column_value, timedelta): pass else: # noinspection PyTypeChecker target_column_value = timedelta(seconds=float(target_column_value)) elif t_type.python_type == bool: if isinstance(target_column_value, bool): pass elif isinstance(target_column_value, str): target_column_value = target_column_value.lower() if target_column_value in ['true', 'yes', 'y']: target_column_value = True elif target_column_value in ['false', 'no', 'n']: target_column_value = True else: type_error = True err_msg = "unexpected value (expected true/false, yes/no, y/n)" else: warnings.warn(f'Table.build_row has no handler for {t_type} = {type(t_type)}') except (ValueError, InvalidOperation) as e: self.log.error(repr(e)) err_msg = repr(e) type_error = True if type_error: msg = ( f"{self}.{target_name} has type {t_type} " f"which cannot accept value '{target_column_value}' {err_msg}" ) raise ValueError(msg) # Check for nulls. 
        #     Not as an ELSE because the conversion logic might have made the value null
        if target_column_value is None:
            if not target_column_object.nullable:
                if not (self.auto_generate_key and target_name in self.primary_key):
                    msg = (
                        f"{self}.{target_name} is not nullable "
                        f"and cannot accept value '{target_column_value}'"
                    )
                    raise ValueError(msg)
        # End if target_column_value is not None:
        return target_column_value
    @staticmethod
    def _get_build_row_name(row: Row) -> str:
        return f"_build_row_{row.iteration_header.iteration_id}"
    def sanity_check_source_mapping(
            self,
            source_definition: ETLComponent,
            source_name: str = None,
            source_excludes: Optional[frozenset] = None,
            target_excludes: Optional[frozenset] = None,
            ignore_source_not_in_target: Optional[bool] = None,
            ignore_target_not_in_source: Optional[bool] = None,
            raise_on_source_not_in_target: Optional[bool] = None,
            raise_on_target_not_in_source: Optional[bool] = None,
    ):
        if target_excludes is None:
            target_excludes = set()
        else:
            target_excludes = set(target_excludes)
        if self.auto_generate_key:
            if self.primary_key is not None:
                # Don't complain about the PK column being missing
                # Note: if the PK is more than one column it will fail in autogenerate_key.
                # Doesn't need to be caught here.
                key = self.get_column_name(list(self.primary_key)[0])
                if key not in source_definition:
                    target_excludes.add(key)
            else:
                raise KeyError('Cannot generate key values without a primary key.')
        if self.delete_flag is not None and self.delete_flag not in source_definition:
            target_excludes.add(self.delete_flag)
        if self.last_update_date is not None and self.last_update_date not in source_definition:
            target_excludes.add(self.last_update_date)
        target_excludes = frozenset(target_excludes)
        super().sanity_check_source_mapping(
            source_definition=source_definition,
            source_name=source_name,
            source_excludes=source_excludes,
            target_excludes=target_excludes,
            ignore_source_not_in_target=ignore_source_not_in_target,
            ignore_target_not_in_source=ignore_target_not_in_source,
            raise_on_source_not_in_target=raise_on_source_not_in_target,
            raise_on_target_not_in_source=raise_on_target_not_in_source,
        )
    def sanity_check_example_row(
            self,
            example_source_row,
            source_excludes: Optional[frozenset] = None,
            target_excludes: Optional[frozenset] = None,
            ignore_source_not_in_target: Optional[bool] = None,
            ignore_target_not_in_source: Optional[bool] = None,
    ):
        self.sanity_check_source_mapping(
            example_source_row,
            example_source_row.name,
            source_excludes=source_excludes,
            target_excludes=target_excludes,
            ignore_source_not_in_target=ignore_source_not_in_target,
            ignore_target_not_in_source=ignore_target_not_in_source,
        )
# TODO: Sanity check primary key data types. # Lookups might fail if the types don't match (although build_row in safe_mode should fix it) def _insert_stmt(self): # pylint: disable=no-value-for-parameter ins = self.table.insert() if self.insert_hint is not None: ins.with_hint(self.insert_hint) return ins def _insert_pending_batch( self, stat_name: str = 'insert', parent_stats: Optional[Statistics] = None, ): if self.in_bulk_mode: raise InvalidOperation('_insert_pending_batch is not allowed for bulk loader') # Need to delete pending first in case we are doing delete & insert pairs self._delete_pending_batch(parent_stats=parent_stats) # Need to update pending first in case we are doing update & insert pairs self._update_pending_batch(parent_stats=parent_stats) if len(self.pending_insert_rows) == 0: return if parent_stats is not None: stats = parent_stats # Keep track of which parent last used pending inserts self.pending_insert_stats = stats else: stats = self.pending_insert_stats if stats is None: stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) if self.insert_method == Table.InsertMethod.insert_values_list: prepare_stats = self.get_stats_entry('prepare ' + stat_name, parent_stats=stats) prepare_stats.print_start_stop_times = False prepare_stats.timer.start() pending_insert_statements = StatementQueue(execute_with_binds=False) for new_row in self.pending_insert_rows: if new_row.status != RowStatus.deleted: stmt_key = new_row.positioned_column_set stmt = pending_insert_statements.get_statement_by_key(stmt_key) if stmt is None: prepare_stats['statements prepared'] += 1 stmt = f"INSERT INTO {self.qualified_table_name} ({', '.join(new_row.columns_in_order)}) VALUES {{values}}" pending_insert_statements.add_statement(stmt_key, stmt) prepare_stats['rows prepared'] += 1 stmt_values = new_row.values() pending_insert_statements.append_values_by_key(stmt_key, stmt_values) # Change the row status to existing, now any updates should be via update statements instead of in-memory updates new_row.status = RowStatus.existing prepare_stats.timer.stop() # TODO: Slow performance on big varchars -- possible cause: # By default pyodbc reads all text columns as SQL_C_WCHAR buffers and decodes them using UTF-16LE. # When writing, it always encodes using UTF-16LE into SQL_C_WCHAR buffers. # Connection.setencoding and Connection.setdecoding functions can override try: db_stats = self.get_stats_entry(stat_name + ' database execute', parent_stats=stats) db_stats.print_start_stop_times = False db_stats.timer.start() rows_affected = pending_insert_statements.execute(self.connection()) db_stats.timer.stop() db_stats['rows inserted'] += rows_affected del self.pending_insert_rows self.pending_insert_rows = list() except Exception as e: # self.log.error(traceback.format_exc()) self.log.exception(e) self.log.error("Bulk insert failed. Applying as single inserts to find error row...") self.rollback() self.begin() # Retry one at a time try: pending_insert_statements.execute_singly(self.connection()) # If that didn't cause the error... re raise the original error self.log.error( f"Single inserts on {self} did not produce the error. Original error will be issued below." 
) self.rollback() raise except Exception as e_single: self.log.error(f"Single inserts got error {e_single}") self.log.exception(e_single) raise else: prepare_stats = self.get_stats_entry('prepare ' + stat_name, parent_stats=stats) prepare_stats.print_start_stop_times = False prepare_stats.timer.start() pending_insert_statements = StatementQueue() for new_row in self.pending_insert_rows: if new_row.status != RowStatus.deleted: stmt_key = new_row.positioned_column_set stmt = pending_insert_statements.get_statement_by_key(stmt_key) if stmt is None: prepare_stats['statements prepared'] += 1 stmt = self._insert_stmt() for c in new_row: try: col_obj = self.get_column(c) stmt = stmt.values({col_obj.name: bindparam(col_obj.name, type_=col_obj.type)}) except KeyError: self.log.error(f'Extra column found in pending_insert_rows row {new_row}') raise # this was causing InterfaceError: (cx_Oracle.InterfaceError) not a query # stmt = stmt.compile() pending_insert_statements.add_statement(stmt_key, stmt) prepare_stats['rows prepared'] += 1 stmt_values = dict() for c in new_row: col_obj = self.get_column(c) stmt_values[col_obj.name] = new_row[c] pending_insert_statements.append_values_by_key(stmt_key, stmt_values) # Change the row status to existing, now any updates should be via update statements new_row.status = RowStatus.existing prepare_stats.timer.stop() try: db_stats = self.get_stats_entry(stat_name + ' database execute', parent_stats=stats) db_stats.print_start_stop_times = False db_stats.timer.start() rows_affected = pending_insert_statements.execute(self.connection()) db_stats.timer.stop() db_stats['rows inserted'] += rows_affected del self.pending_insert_rows self.pending_insert_rows = list() except SQLAlchemyError as e: self.log.error( f"Bulk insert on {self} failed with error {e}. Applying as single inserts to find error row..." ) self.rollback() self.begin() # Retry one at a time try: pending_insert_statements.execute_singly(self.connection()) # If that didn't cause the error... re raise the original error self.log.error( f"Single inserts on {self} did not produce the error. Original error will be issued below." ) self.rollback() raise except Exception as e_single: self.log.error(f"Single inserts got error {e_single}") self.log.exception(e_single) raise
[docs] def insert_row( self, source_row: Row, additional_insert_values: Optional[dict] = None, maintain_cache: Optional[bool] = None, source_excludes: Optional[frozenset] = None, target_excludes: Optional[frozenset] = None, stat_name: str = 'insert', parent_stats: Optional[Statistics] = None, ) -> Row: """ Inserts a row into the database (batching rows as batch_size) Parameters ---------- maintain_cache source_row The row with values to insert additional_insert_values: Values to add / override in the row before inserting. source_excludes: set of source columns to exclude target_excludes set of target columns to exclude stat_name parent_stats Returns ------- new_row """ stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats.timer.start() self.begin() new_row = self.build_row( source_row=source_row, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=stats, ) self.autogenerate_key(new_row, force_override=False) if self.delete_flag is not None: if self.delete_flag not in new_row or new_row[self.delete_flag] is None: new_row.set_keeping_parent(self.delete_flag, self.delete_flag_no) # Set the last update date if self.last_update_date is not None: self.set_last_update_date(new_row) if self.trace_data: self.log.debug(f"{self} Raw row being inserted:\n{new_row.str_formatted()}") if self.in_bulk_mode: if not self.upsert_called: if self._bulk_iter_sentinel is None: self.log.info(f"Setting up {self} worker to collect insert_row rows") # It does make sense to default to not caching inserted rows # UNLESS this table is also used as a lookup. Lookups should have been defined. if len(self.lookups) == 0: self.maintain_cache_during_load = False self._bulk_iter_sentinel = StopIteration self._bulk_iter_queue = Queue() # row_iter = iter(self._bulk_iter_queue.get, self._bulk_iter_sentinel) row_iter = self._bulk_iter_queue self._bulk_iter_worker = spawn( self.bulk_loader.load_from_iterable, row_iter, table_object=self, progress_frequency=None, ) self._bulk_iter_queue.put(new_row) self._bulk_iter_max_q_length = max(self._bulk_iter_max_q_length, len(self._bulk_iter_queue)) # Allow worker to act if it is ready sleep() else: if self.batch_size > 1: new_row.status = RowStatus.insert self.pending_insert_rows.append(new_row) if len(self.pending_insert_rows) >= self.batch_size: self._insert_pending_batch(parent_stats=stats) else: # Immediate insert new_row.status = RowStatus.existing try: stmt_values = dict() for c in new_row: col_obj = self.get_column(c) stmt_values[col_obj.name] = new_row[c] result = self.execute(self._insert_stmt(), stmt_values) except Exception as e: self.log.error(f"Error with source_row = {source_row}") self.log.error(f"Error with inserted row = {new_row.str_formatted()}") raise e stats['rows inserted'] += result.rowcount result.close() if maintain_cache is None: maintain_cache = self.maintain_cache_during_load if maintain_cache: self.cache_row(new_row, allow_update=False) stats.timer.stop() return new_row
[docs] def insert( self, source_row: Union[MutableMapping, Iterable[MutableMapping]], additional_insert_values: Optional[dict] = None, source_excludes: Optional[frozenset] = None, target_excludes: Optional[frozenset] = None, parent_stats: Optional[Statistics] = None, **kwargs ): """ Insert a row or list of rows in the table. Parameters ---------- source_row: :class:`Row` or list thereof Row(s) to insert additional_insert_values: dict Additional values to set on each row. source_excludes list of Row source columns to exclude when mapping to this Table. target_excludes list of Table columns to exclude when mapping from the source Row(s) parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ if isinstance(source_row, MutableMapping): if not isinstance(source_row, self.row_object): warnings.warn('Passing types other than Row to insert is slower.') source_row = self.row_object(data=source_row, iteration_header=self.empty_iteration_header) return self.insert_row( source_row, additional_insert_values=additional_insert_values, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=parent_stats, **kwargs ) else: for row in source_row: if not isinstance(row, self.row_object): warnings.warn('Passing types other than Row to insert is slower.') row = self.row_object(data=row, iteration_header=self.empty_iteration_header) self.insert_row( row, additional_insert_values=additional_insert_values, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=parent_stats, **kwargs ) return None
    def _delete_pending_batch(
            self,
            stat_name: str = 'delete',
            parent_stats: Optional[Statistics] = None,
    ):
        if self.in_bulk_mode:
            raise InvalidOperation('_delete_pending_batch not allowed in bulk mode')
        if self.pending_delete_statements.row_count > 0:
            if parent_stats is not None:
                stats = parent_stats
                # Keep track of which parent last used pending delete
                self.pending_delete_stats = stats
            else:
                stats = self.pending_delete_stats
                if stats is None:
                    stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
            stats['rows batch deleted'] += self.pending_delete_statements.execute(self.connection())

    def _delete_stmt(self):
        # pylint: disable=no-value-for-parameter
        return self.table.delete()
[docs] def delete( self, key_values: Union[Iterable, MutableMapping], lookup_name: Optional[str] = None, key_names: Optional[Iterable] = None, maintain_cache: Optional[bool] = None, stat_name: str = 'delete', parent_stats: Optional[Statistics] = None, ): """ Delete rows matching key_values. If caching is enabled, don't use key_names (use lookup_name or default to PK) or an expensive scan of the cache needs to be performed. Parameters ---------- key_values: dict (or list) Container of key values to use to identify a row(s) to delete. lookup_name: str Optional lookup name those key values are from. key_names Optional list of column names those key values are from. maintain_cache: boolean Maintain the cache when deleting. Can be expensive if key_names is used because scan of the entire cache needs to be performed. Defaults to True stat_name: string Name of this step for the ETLTask statistics. parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats.timer.start() key_values_dict, lookup_name = self._generate_key_values_dict(key_names, key_values, lookup_name) # Don't use key_names or key_values anymore, use key_values_dict del key_names del key_values if self.in_bulk_mode: maintain_cache = True else: self.begin() if maintain_cache is None: maintain_cache = self.maintain_cache_during_load and self.cache_clean if not maintain_cache: self.cache_clean = False if self.batch_size > 1: # Bacth-ed deletes # This mode is difficult to buffer. We might have multiple delete statements used, so we have # to maintain a dictionary of statements by tuple(key_names) and a buffer for each delete_stmt_key = frozenset(key_values_dict.keys()) stmt = self.pending_delete_statements.get_statement_by_key(delete_stmt_key) if stmt is None: stmt = self._delete_stmt() for key_name, key_value in key_values_dict.items(): key = self.get_column(key_name) stmt = stmt.where(key == bindparam(key.name, type_=key.type)) stmt = stmt.compile() self.pending_delete_statements.add_statement(delete_stmt_key, stmt) self.pending_delete_statements.append_values_by_key(delete_stmt_key, key_values_dict) if len(self.pending_delete_statements) >= self.batch_size: self._delete_pending_batch(parent_stats=stats) else: # Deletes issued as we get them stmt = self._delete_stmt() for key_name, key_value in key_values_dict.items(): key = self.get_column(key_name) stmt = stmt.where(key == key_value) del_result = self.execute(stmt) stats['rows deleted'] += del_result.rowcount del_result.close() if maintain_cache: if lookup_name is not None: full_row = self.get_by_lookup(lookup_name, key_values_dict, parent_stats=stats) self.uncache_row(full_row) else: key_columns_set = set(key_values_dict.keys()) found_matching_lookup = False for lookup in self.lookups.values(): if set(lookup.lookup_keys) == key_columns_set: found_matching_lookup = True full_row = lookup.find(key_values_dict, stats=stats) self.uncache_row(full_row) if not found_matching_lookup: warnings.warn( f'{self}.delete called with maintain_cache=True and a set of keys that do not match any lookup.' f'This will be very slow. Keys used = {key_columns_set}' ) self.uncache_where(key_names=key_values_dict.keys(), key_values_dict=key_values_dict) stats.timer.stop()
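    # Illustrative sketch (not part of the original source): deleting by primary-key values
    # (here 'example_id' is a hypothetical single-column PK) so the cached row can be located
    # via the PK lookup rather than an expensive scan of the whole cache.
    #
    #     target.delete(key_values={'example_id': 42})
    #     target.commit()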
    # TODO: Add a lookup parameter and allow the set of key tuples to be based on that lookup's keys instead of the primary key
[docs] def delete_not_in_set( self, set_of_key_tuples: set, lookup_name: Optional[str] = None, criteria_list: Optional[list] = None, criteria_dict: Optional[dict] = None, use_cache_as_source: Optional[bool] = True, stat_name: str = 'delete_not_in_set', progress_frequency: Optional[int] = None, parent_stats: Optional[Statistics] = None, **kwargs ): """ WARNING: This does physical deletes !! See :meth:`logically_delete_in_set` for logical deletes. Deletes rows matching criteria that are not in the list_of_key_tuples pass in. Parameters ---------- set_of_key_tuples List of tuples comprising the primary key values. This list represents the rows that should *not* be deleted. lookup_name: str The name of the lookup to use to find key tuples. criteria_list : string or list of strings Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`. https://goo.gl/JlY9us criteria_dict : dict Dict keys should be columns, values are set using = or in use_cache_as_source: bool Attempt to read existing rows from the cache? stat_name: string Name of this step for the ETLTask statistics. Default = 'delete_not_in_set' progress_frequency: int How often (in seconds) to output progress messages. Default 10. None for no progress messages. parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats) stats['rows read'] = 0 stats['rows deleted'] = 0 stats.timer.start() deleted_rows = list() self.begin() if progress_frequency is None: progress_frequency = self.progress_frequency progress_timer = Timer() # Turn off read progress reports saved_progress_frequency = self.progress_frequency self.progress_frequency = None if lookup_name is None: lookup = self.get_nk_lookup() else: lookup = self.get_lookup(lookup_name) # Builds static list of cache content so that we can modify (physically delete) cache rows # while iterating over row_list = list( self.where( criteria_list=criteria_list, criteria_dict=criteria_dict, use_cache_as_source=use_cache_as_source, parent_stats=stats ) ) for row in row_list: stats['rows read'] += 1 existing_key = lookup.get_hashable_combined_key(row) if 0 < progress_frequency <= progress_timer.seconds_elapsed: progress_timer.reset() self.log.info( f"delete_not_in_set current row={stats['rows read']} key={existing_key} deletes_done = {stats['rows deleted']}" ) if existing_key not in set_of_key_tuples: stats['rows deleted'] += 1 deleted_rows.append(row) # In bulk mode we just need to remove them from the cache, # which is done below where we loop over deleted_rows if not self.in_bulk_mode: self.delete(key_values=row, lookup_name=self._get_pk_lookup_name(), maintain_cache=False) for row in deleted_rows: self.uncache_row(row) stats.timer.stop() # Restore saved progress_frequency self.progress_frequency = saved_progress_frequency
[docs] def delete_not_processed( self, criteria_list: Optional[list] = None, criteria_dict: Optional[dict] = None, use_cache_as_source: bool = True, allow_delete_all: bool = False, stat_name: str = 'delete_not_processed', parent_stats: Optional[Statistics] = None, **kwargs ): """ WARNING: This does physical deletes !! See :meth:`logically_delete_not_processed` for logical deletes. Physically deletes rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`. Parameters ---------- criteria_list : string or list of strings Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`. https://goo.gl/JlY9us criteria_dict : dict Dict keys should be columns, values are set using = or in use_cache_as_source: bool Attempt to read existing rows from the cache? allow_delete_all: Allow this method to delete ALL rows if no source rows were processed stat_name: string Name of this step for the ETLTask statistics. parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ assert self.track_source_rows, "delete_not_processed can't be used if we don't track source rows" if not allow_delete_all and (self.source_keys_processed is None or len(self.source_keys_processed) == 0): # We don't want to logically delete all the rows # But that's only an issue if there are target rows if any(True for _ in self.where(criteria_list=criteria_list, criteria_dict=criteria_dict, )): raise RuntimeError(f"{stat_name} called before any source rows were processed on {self}.") self.delete_not_in_set( set_of_key_tuples=self.source_keys_processed, criteria_list=criteria_list, criteria_dict=criteria_dict, stat_name=stat_name, parent_stats=parent_stats ) self.source_keys_processed = set()
[docs]    def logically_delete_not_in_set(
            self,
            set_of_key_tuples: set,
            lookup_name: Optional[str] = None,
            criteria_list: Optional[list] = None,
            criteria_dict: Optional[dict] = None,
            use_cache_as_source: bool = True,
            stat_name: str = 'logically_delete_not_in_set',
            progress_frequency: Optional[int] = 10,
            parent_stats: Optional[Statistics] = None,
            **kwargs
    ):
        """
        Logically deletes rows matching criteria that are not in the set_of_key_tuples passed in.

        Parameters
        ----------
        set_of_key_tuples
            Set of tuples comprising the primary key values.
            This set represents the rows that should *not* be logically deleted.
        lookup_name:
            Name of the lookup to use
        criteria_list :
            Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
            https://goo.gl/JlY9us
        criteria_dict :
            Dict keys should be columns, values are set using = or in
        use_cache_as_source: bool
            Attempt to read existing rows from the cache?
        stat_name:
            Name of this step for the ETLTask statistics. Default = 'logically_delete_not_in_set'
        progress_frequency:
            How often (in seconds) to output progress messages. Default = 10.
        parent_stats:
            Optional Statistics object to nest this step's statistics in.
            Default is to place statistics in the ETLTask level statistics.
        kwargs:
            If called on the child class HistoryTable:
                effective_date: The effective date to use for this operation.
        """
        if self._logical_delete_update is None:
            iteration_header = self.generate_iteration_header(
                logical_name='logically_deleted',
                columns_in_order=[self.delete_flag],
            )
            self._logical_delete_update = self.row_object(iteration_header=iteration_header)
            self._logical_delete_update[self.delete_flag] = self.delete_flag_yes

        if criteria_dict is None:
            criteria_dict = dict()
        # Do not process rows that are already deleted
        criteria_dict[self.delete_flag] = self.delete_flag_no

        self.update_not_in_set(
            updates_to_make=self._logical_delete_update,
            set_of_key_tuples=set_of_key_tuples,
            lookup_name=lookup_name,
            criteria_list=criteria_list,
            criteria_dict=criteria_dict,
            use_cache_as_source=use_cache_as_source,
            progress_frequency=progress_frequency,
            stat_name=stat_name,
            parent_stats=parent_stats,
            **kwargs
        )
[docs]    def logically_delete_not_processed(
            self,
            criteria_list: Optional[list] = None,
            criteria_dict: Optional[dict] = None,
            use_cache_as_source: bool = True,
            allow_delete_all: bool = False,
            stat_name='logically_delete_not_processed',
            parent_stats: Optional[Statistics] = None,
            **kwargs
    ):
        """
        Logically deletes rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`.

        Parameters
        ----------
        criteria_list:
            Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
            https://goo.gl/JlY9us
        criteria_dict:
            Dict keys should be columns, values are set using = or in
        use_cache_as_source: bool
            Attempt to read existing rows from the cache?
        allow_delete_all:
            Allow this method to logically delete all rows.
            Defaults to False as a safeguard in case an error prevented the processing of any source rows.
        stat_name: string
            Name of this step for the ETLTask statistics. Default = 'logically_delete_not_processed'
        parent_stats: bi_etl.statistics.Statistics
            Optional Statistics object to nest this step's statistics in.
            Default is to place statistics in the ETLTask level statistics.
        """
        if not self.track_source_rows:
            raise ValueError(
                "logically_delete_not_processed can't be used if we don't track source rows"
            )
        if not allow_delete_all and (self.source_keys_processed is None or len(self.source_keys_processed) == 0):
            # We don't want to logically delete all the rows
            # But that's only an issue if there are target rows
            if any(True for _ in self.where(criteria_list=criteria_list, criteria_dict=criteria_dict, )):
                raise RuntimeError(f"{stat_name} called before any source rows were processed on {self}.")
        self.logically_delete_not_in_set(
            set_of_key_tuples=self.source_keys_processed,
            criteria_list=criteria_list,
            criteria_dict=criteria_dict,
            use_cache_as_source=use_cache_as_source,
            stat_name=stat_name,
            parent_stats=parent_stats,
            **kwargs
        )
        self.source_keys_processed = set()
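    # Usage sketch (illustrative only, not part of the original module): the typical
    # "flag what the source no longer sends" pattern.  `source` and `target` are
    # hypothetical components assumed to be set up elsewhere, with track_source_rows
    # enabled on the target.
    #
    #     for row in source:
    #         target.upsert(row)
    #     target.logically_delete_not_processed()
    #     target.commit()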
[docs] def logically_delete_not_in_source( self, source: ReadOnlyTable, source_criteria_list: Optional[list] = None, source_criteria_dict: Optional[dict] = None, target_criteria_list: Optional[list] = None, target_criteria_dict: Optional[dict] = None, use_cache_as_source: Optional[bool] = True, parent_stats: Optional[Statistics] = None ): """ Logically deletes rows matching criteria that are not in the source component passed to this method. The primary use case for this method is when the upsert method is only passed new/changed records and so cannot build a complete set of source keys in source_keys_processed. Parameters ---------- source: The source to read to get the source keys. source_criteria_list : string or list of strings Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`. https://goo.gl/JlY9us source_criteria_dict: Dict keys should be columns, values are set using = or in target_criteria_list : string or list of strings Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`. https://goo.gl/JlY9us target_criteria_dict: Dict keys should be columns, values are set using = or in use_cache_as_source: Attempt to read existing rows from the cache? parent_stats: Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ self.log.info("Processing deletes") self.log.info("...getting source keys") set_of_source_keys = set() for row in source.where( column_list=source.primary_key, criteria_list=source_criteria_list, criteria_dict=source_criteria_dict, parent_stats=parent_stats, ): set_of_source_keys.add(source.get_primary_key_value_tuple(row)) self.log.info("...logically_delete_not_in_set of source keys") self.logically_delete_not_in_set( set_of_source_keys, criteria_list=target_criteria_list, criteria_dict=target_criteria_dict, use_cache_as_source=use_cache_as_source, parent_stats=parent_stats, )
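    # Usage sketch (illustrative only): when upsert only receives new/changed rows, the
    # complete key set can be read from the source table instead of source_keys_processed.
    # `source_table` is a hypothetical ReadOnlyTable sharing the target's natural key,
    # and `target` a hypothetical Table instance set up elsewhere.
    #
    #     target.logically_delete_not_in_source(source=source_table)
    #     target.commit()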
[docs] def get_current_time(self) -> datetime: if self.use_utc_times: # Returns the current UTC date and time, as a naive datetime object. return datetime.utcnow() else: return datetime.now()
[docs] def set_last_update_date(self, row): last_update_date_col = self.get_column_name(self.last_update_date) last_update_coerce = self.get_coerce_method(last_update_date_col) last_update_value = last_update_coerce(self.get_current_time()) row.set_keeping_parent(last_update_date_col, last_update_value)
[docs] def get_bind_name(self, column_name: str) -> str: if self._bind_name_map is None: max_len = self.database.dialect.max_identifier_length self._bind_name_map = dict() bind_names = set() for other_column_name in self.column_names: bind_name = f'b_{other_column_name}' if len(bind_name) > max_len: bind_name = bind_name[:max_len - 2] unique_name_found = False next_disamb = 0 while not unique_name_found: if next_disamb >= 1: bind_name = bind_name[:max_len - len(str(next_disamb))] + str(next_disamb) if bind_name in bind_names: next_disamb += 1 else: unique_name_found = True bind_names.add(bind_name) self._bind_name_map[other_column_name] = bind_name return self._bind_name_map[column_name] else: return self._bind_name_map[column_name]
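    # Illustrative note (assumed example, column name is hypothetical): for a column
    # named customer_name this scheme would normally yield the bind name
    # 'b_customer_name'; names longer than the dialect's max identifier length are
    # truncated and, if truncation causes a collision, suffixed with a number.
    #
    #     bind_name = target.get_bind_name('customer_name')   # e.g. 'b_customer_name'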
    def _update_pending_batch(
            self,
            stat_name: str = 'update',
            parent_stats: Optional[Statistics] = None
    ):
        """
        Execute any pending (buffered) update statements as a batch.

        Parameters
        ----------
        stat_name: str
            Name of this step for the ETLTask statistics. Default = 'update'
        parent_stats: Statistics
            Optional Statistics object to nest this step's statistics in.
        """
        if self.in_bulk_mode:
            raise InvalidOperation('_update_pending_batch not allowed in bulk mode')
        if len(self.pending_update_rows) == 0:
            return
        assert self.primary_key, "_update_pending_batch called for table with no primary key"
        if parent_stats is not None:
            stats = parent_stats
            # Keep track of which parent stats last called us
            self.pending_update_stats = stats
        else:
            stats = self.pending_update_stats
            if stats is None:
                stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
        stats.timer.start()
        self.begin()
        pending_update_statements = StatementQueue()
        for pending_update_rec in self.pending_update_rows.values():
            row_dict = pending_update_rec.update_row_new_values
            if row_dict.status not in [RowStatus.deleted, RowStatus.insert]:
                update_where_values = pending_update_rec.update_where_values
                update_where_keys = pending_update_rec.update_where_keys
                stats['update rows sent to db'] += 1
                update_stmt_key = (update_where_keys, row_dict.column_set)
                update_stmt = pending_update_statements.get_statement_by_key(update_stmt_key)
                if update_stmt is None:
                    update_stmt = self._update_stmt()
                    for key_column in update_where_keys:
                        key = self.get_column(key_column)
                        bind_name = f"_{key.name}"
                        # Note SQLAlchemy takes care of converting to positional "qmark" bind parameters as needed
                        update_stmt = update_stmt.where(key == bindparam(bind_name, type_=key.type))
                    for c in row_dict.columns_in_order:
                        col_obj = self.get_column(c)
                        bind_name = col_obj.name
                        update_stmt = update_stmt.values({c: bindparam(bind_name, type_=col_obj.type)})
                    update_stmt = update_stmt.compile()
                    pending_update_statements.add_statement(update_stmt_key, update_stmt)
                stmt_values = dict()
                for key_column, key_value in zip(update_where_keys, update_where_values):
                    key = self.get_column(key_column)
                    bind_name = f"_{key.name}"
                    stmt_values[bind_name] = key_value
                for c in row_dict.columns_in_order:
                    # Since we know we aren't updating the key, don't send the keys in the values clause
                    col_obj = self.get_column(c)
                    bind_name = col_obj.name
                    stmt_values[bind_name] = row_dict[c]
                pending_update_statements.append_values_by_key(update_stmt_key, stmt_values)
                row_dict.status = RowStatus.existing
        try:
            db_stats = self.get_stats_entry('DB Update', parent_stats=stats)
            db_stats['update statements sent to db'] += 1
            db_stats.timer.start()
            rows_applied = pending_update_statements.execute(self.connection())
            db_stats.timer.stop()
            db_stats['applied rows'] += rows_applied  # len(self.pending_update_rows)
        except Exception as e:
            # Retry one at a time
            self.log.error(
                f"Bulk update on {self} failed with error {e}. Applying as single updates to find error row..."
            )
            connection = self.connection(connection_name='singly')
            for (stmt, row_dict) in pending_update_statements.iter_single_statements():
                try:
                    connection.execute(stmt, row_dict)
                except Exception as e:
                    self.log.error(f"Error with row {dict_to_str(row_dict)}")
                    raise e
            # If that didn't cause the error... re-raise the original error
            self.log.error(
                f"Single updates on {self} did not produce the error. "
                f"Original error will be issued below."
            )
            self.rollback()
            raise e
        self.pending_update_rows.clear()
        stats.timer.stop()

    def _add_pending_update(
            self,
            update_where_keys: Tuple[str],
            update_where_values: Iterable[Any],
            update_row_new_values: Row,
            parent_stats: Optional[Statistics] = None,
    ):
        assert self.primary_key, "add_pending_update called for table with no primary key"
        assert update_row_new_values.status != RowStatus.insert, "add_pending_update called with row that's pending insert"
        assert update_row_new_values.status != RowStatus.deleted, "add_pending_update called with row that's pending delete"
        key_tuple = tuple(update_where_values)
        pending_update_rec = PendingUpdate(
            update_where_keys,
            update_where_values,
            update_row_new_values
        )
        if key_tuple not in self.pending_update_rows:
            self.pending_update_rows[key_tuple] = pending_update_rec
        else:
            # Apply updates to the existing pending update
            existing_update_rec = self.pending_update_rows[key_tuple]
            existing_row = existing_update_rec.update_row_new_values
            for col, value in update_row_new_values.items():
                existing_row[col] = value
        if len(self.pending_update_rows) >= self.batch_size:
            self._update_pending_batch(parent_stats=parent_stats)
[docs] def apply_updates( self, row: Row, changes_list: Optional[MutableSequence[ColumnDifference]] = None, additional_update_values: MutableMapping = None, add_to_cache: bool = True, allow_insert=True, update_apply_entire_row: bool = False, stat_name: str = 'update', parent_stats: Optional[Statistics] = None, **kwargs ): """ This method should only be called with a row that has already been transformed into the correct datatypes and column names. The update values can be in any of the following parameters - row (also used for PK) - changes_list - additional_update_values Parameters ---------- row: The row to update with (needs to have at least PK values) The values are used for the WHERE and so should be before any updates to make. The updates to make can be sent via changes_list or additional_update_values. changes_list: A list of ColumnDifference objects to apply to the row additional_update_values: A Row or dict of additional values to apply to the main row add_to_cache: Should this method update the cache (not if caller will) allow_insert: boolean Allow this method to insert a new row into the cache update_apply_entire_row: Should the update set all non-key columns or just the changed values? stat_name: Name of this step for the ETLTask statistics. Default = 'update' parent_stats: kwargs: Not used in this version """ assert self.primary_key, "apply_updates called for table with no primary key" assert row.status != RowStatus.deleted, f"apply_updates called for deleted row {row}" stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) self.pending_update_stats = stats stats.timer.start() # Set the last update date if self.last_update_date is not None: self.set_last_update_date(row) # Use row to get PK where values BEFORE we apply extra updates to the row # This allows the update WHERE and SET to have different values for the PK fields update_where_keys = self.primary_key_tuple update_where_values = [row[c] for c in update_where_keys] if changes_list is not None: for d_chg in changes_list: if isinstance(d_chg, ColumnDifference): row[d_chg.column_name] = d_chg.new_value else: # Handle if we are sent a Mapping instead of a list of ColumnDifference for changes_list # d_chg would then be the Mapping key (column name) # noinspection PyTypeChecker row[d_chg] = changes_list[d_chg] if additional_update_values is not None: for col_name, value in additional_update_values.items(): row[col_name] = value if add_to_cache and self.maintain_cache_during_load: self.cache_row(row, allow_update=True, allow_insert=allow_insert) if not self.in_bulk_mode: # Check that the row isn't pending an insert if row.status != RowStatus.insert: if update_apply_entire_row or changes_list is None: row.status = RowStatus.update_whole self._add_pending_update( update_where_keys=update_where_keys, update_where_values=update_where_values, update_row_new_values=row, parent_stats=stats ) else: # Support partial updates. # We could save significant network transit times in some cases if we didn't send the whole row # back as an update. However, we might also not want a different update statement for each # combination of updated columns. Unfortunately, having code in update_pending_batch scan the # combinations for the most efficient ways to group the updates would also likely be slow. 
update_columns = {d_chg.column_name for d_chg in changes_list} if additional_update_values is not None: update_columns.update(additional_update_values.keys()) partial_row = row.subset(keep_only=update_columns) partial_row.status = RowStatus.update_partial self._add_pending_update( update_where_keys=update_where_keys, update_where_values=update_where_values, update_row_new_values=partial_row, parent_stats=stats ) if stats is not None: stats['update called'] += 1 else: # The update to the pending row is all we need to do if stats is not None: stats['updated pending insert'] += 1 stats.timer.stop()
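    # Usage sketch (illustrative only): apply_updates expects a full target row, already
    # in target column names and datatypes, for example one returned by get_by_key or a
    # lookup.  `target`, `nk_row`, and `status_code` below are hypothetical.
    #
    #     existing = target.get_by_key(nk_row)   # nk_row assumed to carry the key column(s)
    #     target.apply_updates(
    #         existing,
    #         additional_update_values={'status_code': 'C'},
    #     )
    #     target.commit()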
def _update_stmt(self) -> UpdateBase: return self.table.update()
[docs]    def update_where_pk(
            self,
            updates_to_make: Row,
            key_values: Union[Row, dict, list] = None,
            source_excludes: Optional[frozenset] = None,
            target_excludes: Optional[frozenset] = None,
            stat_name: str = 'update_where_pk',
            parent_stats: Optional[Statistics] = None,
    ):
        """
        Updates the table using the primary key for the where clause.

        Parameters
        ----------
        updates_to_make:
            Updates to make to the rows matching the criteria.
            Can also be used to pass the key_values, so you can pass a single
            :class:`~bi_etl.components.row.row_case_insensitive.Row` or ``dict``
            to the call and have it automatically get the filter values and updates from it.
        key_values:
            Optional. dict or list of key values to apply as criteria
        source_excludes:
            Optional. Set of source column names to exclude when mapping to this Table.
        target_excludes:
            Optional. Set of target column names to exclude when mapping from the source
            :class:`~bi_etl.components.row.row_case_insensitive.Row` (s)
        stat_name:
            Name of this step for the ETLTask statistics. Default = 'update_where_pk'
        parent_stats:
            Optional. Statistics object to nest this step's statistics in.
            Default is to place statistics in the ETLTask level statistics.
        """
        stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
        source_mapped_as_target_row = self.build_row(
            source_row=updates_to_make,
            source_excludes=source_excludes,
            target_excludes=target_excludes,
            parent_stats=stats,
        )
        key_names = self.primary_key
        key_values_dict, _ = self._generate_key_values_dict(key_names, key_values, other_values_dict=updates_to_make)
        # Don't use key_names or key_values anymore, use key_values_dict
        del key_names
        del key_values
        source_mapped_as_target_row.update(key_values_dict)
        self.apply_updates(
            source_mapped_as_target_row,
            allow_insert=False,
            parent_stats=stats,
        )
[docs] def update( self, updates_to_make: Union[Row, dict], key_names: Optional[Iterable] = None, key_values: Optional[Iterable] = None, lookup_name: Optional[str] = None, update_all_rows: bool = False, source_excludes: Optional[frozenset] = None, target_excludes: Optional[frozenset] = None, stat_name: str = 'direct update', parent_stats: Optional[Statistics] = None, connection_name: Optional[str] = None, ): """ Directly performs a database update. Invalidates caching. THIS METHOD IS SLOW! If you have a full target row, use apply_updates instead. Parameters ---------- updates_to_make: Updates to make to the rows matching the criteria Can also be used to pass the key_values, so you can pass a single :class:`~bi_etl.components.row.row_case_insensitive.Row` or ``dict`` to the call and have it automatically get the filter values and updates from it. key_names: Optional. List of columns to apply criteria too (see ``key_values``). Defaults to Primary Key columns. key_values: Optional. List of values to apply as criteria (see ``key_names``). If not provided, and ``update_all_rows`` is False, look in ``updates_to_make`` for values. lookup_name: str Name of the lookup to use update_all_rows: Optional. Defaults to False. If set to True, ``key_names`` and ``key_values`` are not required. source_excludes Optional. list of :class:`~bi_etl.components.row.row_case_insensitive.Row` source columns to exclude when mapping to this Table. target_excludes Optional. list of Table columns to exclude when mapping from the source :class:`~bi_etl.components.row.row_case_insensitive.Row` (s) stat_name: str Name of this step for the ETLTask statistics. Default = 'direct update' parent_stats: bi_etl.statistics.Statistics Optional. Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. connection_name: Name of the pooled connection to use Defaults to DEFAULT_CONNECTION_NAME """ stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats.timer.start() try: _ = updates_to_make.columns_in_order except AttributeError: warnings.warn('Passing types other than Row to update is slower.') updates_to_make = self.row_object(data=updates_to_make, iteration_header=self.empty_iteration_header) # Check if we can pass this of to update_where_pk if key_names is None and not update_all_rows: self.update_where_pk( updates_to_make=updates_to_make, key_values=key_values, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=stats ) return source_mapped_as_target_row = self.build_row( source_row=updates_to_make, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=stats, ) stmt = self._update_stmt() if not update_all_rows: key_values_dict, lookup_name = self._generate_key_values_dict( key_names, key_values, lookup_name=lookup_name, other_values_dict=updates_to_make ) # Don't use key_names or key_values anymore, use key_values_dict del key_names del key_values for key_name, key_value in key_values_dict.items(): key = self.get_column(key_name) stmt = stmt.where(key == key_value) # We could optionally scan the entire cache and apply updates there # instead for now, we'll un-cache the row and set the lookups to fall back to the database # TODO: If we have a lookup based on an updated value, the original value would still be # in the lookup. We don't know the original without doing a lookup. 
self.cache_clean = False if self.maintain_cache_during_load: self.uncache_row(key_values_dict) else: assert not key_values, "update_all_rows set and yet we got key_values" # End not update all rows # Add set statements to the update for c in source_mapped_as_target_row: v = source_mapped_as_target_row[c] stmt = stmt.values({c: v}) self.begin() stats['direct_updates_sent_to_db'] += 1 result = self.execute(stmt, connection_name=connection_name) stats['direct_updates_applied_to_db'] += result.rowcount result.close() stats.timer.stop()
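    # Usage sketch (illustrative only): a direct database UPDATE matched on the primary
    # key.  This bypasses and dirties the row cache, so it suits occasional corrections
    # rather than the main load loop.  Column and key values below are hypothetical.
    #
    #     target.update(
    #         updates_to_make={'status_code': 'X'},
    #         key_values=[42],          # matched against the primary key column(s)
    #     )
    #     target.commit()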
[docs]    def update_not_in_set(
            self,
            updates_to_make,
            set_of_key_tuples: set,
            lookup_name: Optional[str] = None,
            criteria_list: Optional[list] = None,
            criteria_dict: Optional[dict] = None,
            progress_frequency: Optional[int] = None,
            stat_name: str = 'update_not_in_set',
            parent_stats: Optional[Statistics] = None,
            **kwargs
    ):
        """
        Applies update to all rows matching criteria that are not in the set_of_key_tuples passed in.

        Parameters
        ----------
        updates_to_make: :class:`~bi_etl.components.row.row_case_insensitive.Row`
            :class:`~bi_etl.components.row.row_case_insensitive.Row` or dict of updates to make
        set_of_key_tuples
            Set of tuples comprising the primary key values.
            This set represents the rows that should *not* be updated.
        lookup_name: str
            The name of the lookup to use to find key tuples.
        criteria_list : string or list of strings
            Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
            https://goo.gl/JlY9us
        criteria_dict:
            Dict keys should be columns, values are set using = or in
        stat_name: string
            Name of this step for the ETLTask statistics. Default = 'update_not_in_set'
        progress_frequency: int
            How often (in seconds) to output progress messages. Default 10. None for no progress messages.
        parent_stats: bi_etl.statistics.Statistics
            Optional Statistics object to nest this step's statistics in.
            Default is to place statistics in the ETLTask level statistics.
        """
        stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
        stats['rows read'] = 0
        stats['updates count'] = 0
        stats.timer.start()
        self.begin()
        if progress_frequency is None:
            progress_frequency = self.progress_frequency
        progress_timer = Timer()
        # Turn off read progress reports
        saved_progress_frequency = self.progress_frequency
        self.progress_frequency = None

        if lookup_name is None:
            lookup = self.get_nk_lookup()
        else:
            lookup = self.get_lookup(lookup_name)

        if criteria_dict is None:
            # Default to not processing rows that are already deleted
            criteria_dict = {self.delete_flag: self.delete_flag_no}
        else:
            # Add rule to avoid processing rows that are already deleted
            criteria_dict[self.delete_flag] = self.delete_flag_no

        # Note, here we select only lookup columns from self
        for row in self.where(
                column_list=lookup.lookup_keys,
                criteria_list=criteria_list,
                criteria_dict=criteria_dict,
                connection_name='select',
                parent_stats=stats
        ):
            if row.status == RowStatus.unknown:
                pass
            stats['rows read'] += 1
            existing_key = lookup.get_hashable_combined_key(row)
            if 0 < progress_frequency <= progress_timer.seconds_elapsed:
                progress_timer.reset()
                self.log.info(
                    f"update_not_in_set current row#={stats['rows read']:,} row key={existing_key} "
                    f"updates done so far = {stats['updates count']:,}"
                )
            if existing_key not in set_of_key_tuples:
                stats['updates count'] += 1

                try:
                    # First we need the entire existing row
                    target_row = lookup.find(row)
                except NoResultFound:
                    raise RuntimeError(
                        f"keys {row.as_key_value_list} were found in the database or cache "
                        f"but are not found now by the lookup"
                    )
                # Then we can apply the updates to it
                self.apply_updates(
                    row=target_row,
                    additional_update_values=updates_to_make,
                    allow_insert=False,
                    parent_stats=stats,
                )
        stats.timer.stop()
        # Restore saved progress_frequency
        self.progress_frequency = saved_progress_frequency
[docs] def update_not_processed( self, update_row, lookup_name: Optional[str] = None, criteria_list: Optional[Iterable] = None, criteria_dict: Optional[MutableMapping] = None, use_cache_as_source: bool = True, stat_name: Optional[str] = 'update_not_processed', parent_stats: Optional[Statistics] = None, **kwargs ): """ Applies update to all rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`. Parameters ---------- update_row: :class:`~bi_etl.components.row.row_case_insensitive.Row` or dict of updates to make criteria_list : Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`. https://goo.gl/JlY9us criteria_dict: Dict keys should be columns, values are set using = or in lookup_name: Optional lookup name those key values are from. use_cache_as_source: Attempt to read existing rows from the cache? stat_name: Name of this step for the ETLTask statistics. parent_stats: Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. kwargs: IF HistoryTable or child thereof: effective_date: datetime The effective date to use for the update """ assert self.track_source_rows, "update_not_processed can't be used if we don't track source rows" if self.source_keys_processed is None or len(self.source_keys_processed) == 0: # We don't want to logically delete all the rows # But that's only an issue if there are target rows if any( True for _ in self.where( criteria_list=criteria_list, criteria_dict=criteria_dict, ) ): raise RuntimeError(f"{stat_name} called before any source rows were processed.") self.update_not_in_set( updates_to_make=update_row, set_of_key_tuples=self.source_keys_processed, criteria_list=criteria_list, criteria_dict=criteria_dict, use_cache_as_source=use_cache_as_source, stat_name=stat_name, parent_stats=parent_stats, **kwargs ) self.source_keys_processed = set()
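    # Usage sketch (illustrative only): mark rows the source no longer sends without
    # deleting them, using a custom status column.  `source`, `target`, and
    # `status_code` are hypothetical; track_source_rows is assumed to be enabled.
    #
    #     for row in source:
    #         target.upsert(row)
    #     target.update_not_processed({'status_code': 'INACTIVE'})
    #     target.commit()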
def _set_upsert_mode(self): if self._bulk_iter_sentinel is not None: raise ValueError( "Upsert called on a table that is using insert_row method. " "Please manually set upsert_called = True before load operations." ) if self.in_bulk_mode and not self.upsert_called: # One time setup pk_lookup = self.get_pk_lookup() pk_lookup.cache_enabled = True self.cache_filled = True self.cache_clean = True self.always_fallback_to_db = False self.upsert_called = True
[docs] def upsert( self, source_row: Union[Row, List[Row]], lookup_name: Optional[str] = None, skip_update_check_on: Optional[frozenset] = None, do_not_update: Optional[Iterable] = None, additional_update_values: Optional[dict] = None, additional_insert_values: Optional[dict] = None, update_callback: Optional[Callable[[list, Row], None]] = None, insert_callback: Optional[Callable[[Row], None]] = None, source_excludes: Optional[frozenset] = None, target_excludes: Optional[frozenset] = None, stat_name: str = 'upsert', parent_stats: Statistics = None, **kwargs ): """ Update (if changed) or Insert a row in the table. This command will look for an existing row in the target table (using the primary key lookup if no alternate lookup name is provided). If no existing row is found, an insert will be generated. If an existing row is found, it will be compared to the row passed in. If changes are found, an update will be generated. Returns the row found/inserted, with the auto-generated key (if that feature is enabled) Parameters ---------- source_row :class:`~bi_etl.components.row.row_case_insensitive.Row` to upsert lookup_name The name of the lookup (see :meth:`define_lookup`) to use when searching for an existing row. skip_update_check_on List of column names to not compare old vs new for updates. do_not_update List of columns to never update. additional_update_values Additional updates to apply when updating additional_insert_values Additional values to set on each row when inserting. update_callback: Function to pass updated rows to. Function should not modify row. insert_callback: Function to pass inserted rows to. Function should not modify row. source_excludes list of Row source columns to exclude when mapping to this Table. target_excludes list of Table columns to exclude when mapping from the source Row(s) stat_name Name of this step for the ETLTask statistics. Default = 'upsert' parent_stats Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. 
""" stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats.ensure_exists('upsert source row count') self._set_upsert_mode() stats.timer.start() self.begin() stats['upsert source row count'] += 1 # Check for existing row source_mapped_as_target_row = self.build_row( source_row=source_row, source_excludes=source_excludes, target_excludes=target_excludes, parent_stats=stats, ) if self.delete_flag is not None: if self.delete_flag not in source_mapped_as_target_row \ or source_mapped_as_target_row[self.delete_flag] is None: source_mapped_as_target_row.set_keeping_parent(self.delete_flag, self.delete_flag_no) try: # We'll default to using the natural key or primary key or NK based on row columns present if lookup_name is None: lookup_name = self.get_default_lookup(source_row.iteration_header) if isinstance(lookup_name, Lookup): lookup_object = lookup_name else: lookup_object = self.get_lookup(lookup_name) if not lookup_object.cache_enabled and self.maintain_cache_during_load and self.batch_size > 1: raise AssertionError("Caching needs to be turned on if batch mode is on!") if self.in_bulk_mode and not (lookup_object.cache_enabled and self.maintain_cache_during_load): raise AssertionError("Caching needs to be turned on if bulk mode is on!") del lookup_object existing_row = self.get_by_lookup(lookup_name, source_mapped_as_target_row, parent_stats=stats) if self.trace_data: lookup_keys = self.get_lookup(lookup_name).get_list_of_lookup_column_values(source_mapped_as_target_row) else: lookup_keys = None if do_not_update is None: do_not_update = set() if skip_update_check_on is None: skip_update_check_on = set() # In case cache was built using an incomplete row, set those missing columns to None before compare cache_missing_columns = self.column_names_set - existing_row.column_set for missing_column in cache_missing_columns: existing_row.set_keeping_parent(missing_column, None) changes_list = existing_row.compare_to( source_mapped_as_target_row, exclude=do_not_update | skip_update_check_on, coerce_types=False, ) conditional_changes = existing_row.compare_to( source_mapped_as_target_row, compare_only=skip_update_check_on, coerce_types=False, ) if self.track_update_columns: col_stats = self.get_stats_entry('updated columns', parent_stats=stats) for chg in changes_list: col_stats.add_to_stat(key=chg.column_name, increment=1) if self.trace_data: self.log.debug( f"{lookup_keys} {chg.column_name} changed from {chg.old_value} to {chg.new_value}" ) if len(changes_list) > 0: conditional_changes_msg = 'Conditional change applied' else: conditional_changes_msg = 'Conditional change NOT applied' for chg in conditional_changes: col_stats.add_to_stat(key=chg.column_name, increment=1) if self.trace_data: self.log.debug( f"{lookup_keys} {conditional_changes_msg}: {chg.column_name} changed from {chg.old_value} to {chg.new_value}" ) if len(changes_list) > 0: changes_list = list(changes_list) + list(conditional_changes) self.apply_updates( row=existing_row, changes_list=changes_list, additional_update_values=additional_update_values, parent_stats=stats ) if update_callback: update_callback(changes_list, existing_row) target_row = existing_row except NoResultFound: new_row = source_mapped_as_target_row if additional_insert_values: for colName, value in additional_insert_values.items(): new_row[colName] = value self.autogenerate_key(new_row, force_override=True) self.insert_row(new_row, parent_stats=stats) stats['insert new called'] += 1 if insert_callback: insert_callback(new_row) target_row = new_row 
if self.track_source_rows: # Keep track of source records, so we can check if target rows don't exist in source # Note: We use the target_row here since it has already been translated to match the target table # It's also important that it have the existing surrogate key (if any) self.source_keys_processed.add(self.get_natural_key_tuple(target_row)) stats.timer.stop() return target_row
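    # Usage sketch (illustrative only): a simple load where each source row is either
    # inserted or, if the lookup finds an existing row, compared and updated only when
    # something changed.  `source`, `target`, and the column names are hypothetical.
    #
    #     for row in source:
    #         target.upsert(
    #             row,
    #             skip_update_check_on=frozenset({'etl_load_id'}),       # hypothetical audit column
    #             additional_insert_values={'source_system': 'CRM'},     # hypothetical default
    #         )
    #     target.commit()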
[docs]    def upsert_by_pk(
            self,
            source_row: Row,
            stat_name='upsert_by_pk',
            parent_stats: Optional[Statistics] = None,
            **kwargs
    ):
        """
        Used by :meth:`bi_etl.components.table.Table.upsert_special_values_rows` to find and update rows
        by the full PK. Not expected to be useful outside that use case.

        Parameters
        ----------
        source_row: :class:`~bi_etl.components.row.row_case_insensitive.Row`
            Row to upsert
        stat_name: string
            Name of this step for the ETLTask statistics. Default = 'upsert_by_pk'
        parent_stats: bi_etl.statistics.Statistics
            Optional Statistics object to nest this step's statistics in.
            Default is to place statistics in the ETLTask level statistics.
        """
        stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
        stats.timer.start()
        self.begin()
        source_row = self.build_row(source_row)
        if self.track_source_rows:
            # Keep track of source records, so we can check if target rows don't exist in source
            self.source_keys_processed.add(self.get_natural_key_tuple(source_row))
        try:
            existing_row = self.get_by_key(source_row)
            changes_list = existing_row.compare_to(source_row, coerce_types=False)
            if len(changes_list) > 0:
                self.apply_updates(
                    existing_row,
                    changes_list=changes_list,
                    stat_name=stat_name,
                    parent_stats=stats,
                    **kwargs
                )
        except NoResultFound:
            self.insert_row(source_row)
        stats.timer.stop()
[docs] def upsert_special_values_rows( self, stat_name: str = 'upsert_special_values_rows', parent_stats: Optional[Statistics] = None, ): """ Send all special values rows to upsert to ensure they exist and are current. Rows come from :meth:`get_missing_row`, :meth:`get_invalid_row`, :meth:`get_not_applicable_row`, :meth:`get_various_row` Parameters ---------- stat_name: str Name of this step for the ETLTask statistics. Default = 'upsert_special_values_rows' parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ self.log.info(f"Checking special values rows for {self}") stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats['calls'] += 1 stats.timer.start() save_auto_gen = self.auto_generate_key self.auto_generate_key = False special_rows = [ self.get_missing_row(), self.get_invalid_row(), self.get_not_applicable_row(), self.get_various_row(), self.get_none_selected_row() ] for source_row in special_rows: self.upsert(source_row, parent_stats=stats, lookup_name=self._get_pk_lookup_name()) if not self.in_bulk_mode: self.commit(parent_stats=stats) self.auto_generate_key = save_auto_gen stats.timer.stop()
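    # Usage sketch (illustrative only): ensure the special dimension members (missing,
    # invalid, not applicable, various, none selected) exist and are current before the
    # main load.  `target` and `source` are hypothetical components set up elsewhere.
    #
    #     target.upsert_special_values_rows()
    #     for row in source:
    #         target.upsert(row)
    #     target.commit()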
[docs] def truncate( self, timeout: int = 60, stat_name: str = 'truncate', parent_stats: Optional[Statistics] = None, ): """ Truncate the table if possible, else delete all. Parameters ---------- timeout: int How long in seconds to wait for the truncate. Oracle only. stat_name: str Name of this step for the ETLTask statistics. Default = 'truncate' parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. """ stats = self.get_stats_entry(stat_name, parent_stats=parent_stats) stats['calls'] += 1 stats.timer.start() database_type = self.database.dialect_name truncate_sql = sqlalchemy.text(f'TRUNCATE TABLE {self.quoted_qualified_table_name}') with self.database.begin(): if database_type == 'oracle': self.execute(f'alter session set ddl_lock_timeout={timeout}') self.execute(truncate_sql) elif database_type in {'mssql', 'mysql', 'postgresql', 'sybase', 'redshift'}: self.execute(truncate_sql) else: self.execute(self._delete_stmt()) stats.timer.stop()
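    # Usage sketch (illustrative only): empty the target before a full reload.  On
    # dialects without TRUNCATE support this falls back to deleting all rows.
    # `target` and `source` are hypothetical components set up elsewhere.
    #
    #     target.truncate()
    #     for row in source:
    #         target.insert_row(row)
    #     target.commit()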
[docs] def transaction( self, connection_name: Optional[str] = None ): connection_name = self.database.resolve_connection_name(connection_name) self._connections_used.add(connection_name) return self.database.begin(connection_name)
[docs] def begin( self, connection_name: Optional[str] = None ): if not self.in_bulk_mode: return self.transaction(connection_name)
[docs]    def bulk_load_from_cache(
            self,
            temp_table: Optional[str] = None,
            stat_name: str = 'bulk_load_from_cache',
            parent_stats: Optional[Statistics] = None,
    ):
        if self.in_bulk_mode:
            if self._bulk_rollback:
                self.log.info("Not performing bulk load due to rollback")
            else:
                if self.bulk_loader is not None:
                    stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
                    stats.timer.start()
                    assert isinstance(
                        self.bulk_loader, BulkLoader
                    ), f'bulk_loader property needs to be an instance of BulkLoader not {type(self.bulk_loader)}'
                    self.close_connections()
                    if self._bulk_iter_sentinel is None:
                        if self.upsert_called:
                            self.log.info(f"Bulk load from lookup cache for {self}")
                            if temp_table is None:
                                temp_table = self.qualified_table_name
                            row_count = self.bulk_loader.load_table_from_cache(self, temp_table)
                            stats['rows_loaded'] = row_count
                        else:
                            self.log.info(
                                f"Bulk load for {self} -- nothing to do. Neither upsert nor insert_row was called."
                            )
                    else:
                        self.log.info(f"Bulk load from insert_row for {self}")
                        # Finish the iterator on the queue
                        self._bulk_iter_queue.put(self._bulk_iter_sentinel)
                        sleep(1)
                        self._bulk_iter_max_q_length = max(self._bulk_iter_max_q_length, len(self._bulk_iter_queue))
                        stats['max_insert_queue'] = self._bulk_iter_max_q_length
                        stats['ending_insert_queue'] = len(self._bulk_iter_queue)
                        self.log.info("Waiting for bulk loader")
                        self._bulk_iter_worker.join()
                        stats['rows_loaded'] = self._bulk_iter_worker.value
                        self._bulk_iter_sentinel = None
                        if self._bulk_iter_worker.exception is not None:
                            raise self._bulk_iter_worker.exception
                    stats.timer.stop()
                    self._bulk_load_performed = True
                else:
                    raise ValueError('bulk_loader not set')
        else:
            raise ValueError('bulk_load_from_cache called but in_bulk_mode is False')
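    # Usage sketch (illustrative only, details depend on project setup): in bulk mode
    # nothing is written row by row; upsert/insert_row accumulate into the lookup cache
    # (or the insert queue) and the assigned BulkLoader pushes everything at the end.
    # How bulk mode and the loader are configured is assumed to happen elsewhere.
    #
    #     for row in source:
    #         target.upsert(row)
    #     target.bulk_load_from_cache()   # or target.close(), which triggers the load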
[docs] def commit( self, stat_name: str = 'commit', parent_stats: Optional[Statistics] = None, print_to_log: bool = True, connection_name: Optional[str] = None, begin_new: bool = True, ): """ Flush any buffered deletes, updates, or inserts Parameters ---------- stat_name: str Name of this step for the ETLTask statistics. Default = 'commit' parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. print_to_log: bool Should this add a debug log entry for each commit. Defaults to true. connection_name: Name of the pooled connection to use Defaults to DEFAULT_CONNECTION_NAME begin_new: Start a new transaction after commit """ if self.in_bulk_mode: self.log.debug( f"{self}.commit() does nothing in bulk mode. " f"{self}.bulk_load_from_cache() or {self}.close() will load the pending data." ) # Commit is not final enough to perform the bulk load. else: # insert_pending_batch calls other *_pending methods self._insert_pending_batch() stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats) # Start & stop appropriate timers for each stats['commit count'] += 1 stats.timer.start() if connection_name is not None: self.log.debug(f"{self} commit on specified connection {connection_name}") self.database.commit(connection_name) if begin_new: self.begin(connection_name=connection_name) else: # We need to close these in the correct order to prevent deadlocks commit_order = [ # Read only 'select', 'max', 'get_one', # Write 'default', # Used for inserts 'sql_upsert', 'singly', # Really should not be open, it gets rolled back # Any others that exist will be commited last ] self.log.info( f"{self} commit on all active connections " "(if this deadlocks, provide explicit connection names " "to commit in the correct order)." ) for possible_connection_name in commit_order: if possible_connection_name in self._connections_used: self.database.commit(possible_connection_name) if begin_new: self.begin(connection_name=possible_connection_name) for used_connection_name in self._connections_used: if used_connection_name not in commit_order: self.database.commit(used_connection_name) if begin_new: self.begin(connection_name=used_connection_name) self._connections_used = set() stats.timer.stop()
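    # Usage sketch (illustrative only): commit() flushes the remaining partial batch of
    # buffered inserts/updates/deletes and commits the pooled connections; in bulk mode
    # it is a no-op.  `target` and `source` are hypothetical components set up elsewhere.
    #
    #     target.batch_size = 1000
    #     for row in source:
    #         target.upsert(row)
    #     target.commit()   # flushes whatever is still buffered, then commits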
[docs] def rollback( self, stat_name: str = 'rollback', parent_stats: Optional[Statistics] = None, connection_name: Optional[str] = None, begin_new: bool = True, ): """ Rollback any uncommitted deletes, updates, or inserts. Parameters ---------- stat_name: str Name of this step for the ETLTask statistics. Default = 'rollback' parent_stats: bi_etl.statistics.Statistics Optional Statistics object to nest this steps statistics in. Default is to place statistics in the ETLTask level statistics. connection_name: Name of the connection to rollback begin_new: Start a new transaction after rollback """ if self.in_bulk_mode: self._bulk_rollback = True else: self.log.debug("Rolling back transaction") stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats) stats['calls'] += 1 stats.timer.start() if connection_name is not None: self.log.debug(f"{self} rollback on specified connection {connection_name}") self.database.rollback(connection_name) self._connections_used.remove(connection_name) else: self.log.debug(f"{self} rollback on all active connections {self._connections_used}") for connection_name in self._connections_used: self.database.rollback(connection_name) self._connections_used = set() if begin_new: self.begin(connection_name=connection_name) stats.timer.stop()
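    # Usage sketch (illustrative only): discarding buffered work when the load fails.
    # `target` and `source` are hypothetical components set up elsewhere.
    #
    #     try:
    #         for row in source:
    #             target.upsert(row)
    #         target.commit()
    #     except Exception:
    #         target.rollback()
    #         raise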
@staticmethod def _sql_column_not_equals(column_name, alias_1='e', alias_2='s'): return "(" \ f"( {alias_1}.{column_name} != {alias_2}.{column_name} )" \ f" OR ({alias_1}.{column_name} IS NULL AND {alias_2}.{column_name} IS NOT NULL)" \ f" OR ({alias_1}.{column_name} IS NOT NULL AND {alias_2}.{column_name} IS NULL)" \ ")" def _safe_temp_table_name(self, proposed_name): max_identifier_length = self.database.bind.dialect.max_identifier_length if len(proposed_name) > max_identifier_length: alpha_code = int2base(abs(hash(proposed_name)), 36) proposed_name = f"tmp{alpha_code}" if len(proposed_name) > max_identifier_length: proposed_name = proposed_name[:max_identifier_length] return proposed_name @staticmethod def _sql_indent_join( statement_list: Iterable[str], indent: int, prefix: str = '', suffix: str = '', add_newline: bool = True, prefix_if_not_empty: str = '', ) -> str: if add_newline: newline = '\n' else: newline = '' statement_list = list(statement_list) if len(statement_list) == 0: prefix_if_not_empty = '' return prefix_if_not_empty + f"{suffix}{newline}{' ' * indent}{prefix}".join(statement_list) def _sql_primary_key_name(self) -> str: if len(self.primary_key) != 1: raise ValueError(f"{self}.upsert_db_exclusive requires single primary_key column. Got {self.primary_key}") return self.primary_key[0] def _sql_now(self) -> str: return "CURRENT_TIMESTAMP" def _sql_add_timedelta( self, column_name: str, delta: timedelta, ) -> str: database_type = self.database.dialect_name if database_type in {'oracle'}: if delta.seconds != 0 or delta.microseconds != 0: return f"{column_name} + INTERVAL '{delta.total_seconds()}' SECOND" elif delta.days != 0: return f"{column_name} + {delta.days}" elif database_type in {'postgresql'}: if delta.seconds != 0 or delta.microseconds != 0: return f"{column_name} + INTERVAL '{delta.total_seconds()} SECONDS'" elif delta.days != 0: return f"{column_name} + INTERVAL '{delta.days} DAY'" elif database_type in {'mysql'}: if delta.seconds != 0: return f"DATE_ADD({column_name}, INTERVAL {delta.total_seconds()} SECOND)" elif delta.microseconds != 0: return f"DATE_ADD({column_name}, INTERVAL {delta.microseconds} MICROSECOND)" elif delta.days != 0: return f"DATE_ADD({column_name}, INTERVAL {delta.days} DAY)" elif database_type in {'sqlite'}: if delta.seconds != 0 or delta.microseconds != 0: return f"datetime({column_name}, '{delta.total_seconds()} SECONDS')" elif delta.days != 0: return f"datetime({column_name}, '{delta.days:+d} DAYS')" else: if delta.seconds != 0 or delta.microseconds != 0: return f"DATEADD(second, {delta.total_seconds()}, {column_name})" elif delta.days != 0: return f"DATEADD(day, {delta.days}, {column_name})" raise ValueError(f"No _sql_add_timedelta defined for DB {database_type} and timedelta {delta}") def _sql_date_literal( self, dt: datetime, ): database_type = self.database.dialect_name if database_type in {'oracle'}: return f"TIMESTAMP '{dt.isoformat(sep=' ')}'" elif database_type in {'postgresql', 'redshift'}: return f"'{dt.isoformat()}'::TIMESTAMP" elif database_type in {'sqlite'}: return f"datetime('{dt.isoformat()}')" else: return f"TIMESTAMP('{dt.isoformat()}')" def _sql_insert_new( self, source_table: ReadOnlyTable, matching_columns: Iterable, effective_date: Optional[datetime] = None, connection_name: str = 'sql_upsert', ): if not self.auto_generate_key: raise ValueError(f"_sql_insert_new expects to only be used with auto_generate_key = True") now_str = self.get_current_time().isoformat() insert_new_sql = f""" INSERT INTO 
            {self.qualified_table_name} (
                {self._sql_primary_key_name()},
                {self.delete_flag},
                {self.last_update_date},
                {Table._sql_indent_join(matching_columns, 24, suffix=',')}
            )
            WITH max_srgt AS (SELECT coalesce(max({self._sql_primary_key_name()}),0) as max_srgt FROM {self.qualified_table_name})
            SELECT
                max_srgt.max_srgt + ROW_NUMBER() OVER (order by 1) as {self._sql_primary_key_name()},
                '{self.delete_flag_no}' as {self.delete_flag},
                '{now_str}' as {self.last_update_date},
                {Table._sql_indent_join(matching_columns, 16, suffix=',')}
            FROM max_srgt
              CROSS JOIN {source_table.qualified_table_name} s
            WHERE NOT EXISTS(
                SELECT 1
                FROM {self.qualified_table_name} e
                WHERE {Table._sql_indent_join([f"e.{nk_col} = s.{nk_col}" for nk_col in self.natural_key], 18, prefix='AND ')}
            )
        """
        self.log.debug('=' * 80)
        self.log.debug('insert_new_sql')
        self.log.debug('=' * 80)
        self.log.debug(insert_new_sql)
        sql_timer = Timer()
        results = self.execute(insert_new_sql, connection_name=connection_name)
        self.log.debug(f"Impacted rows = {results.rowcount} (meaningful count? {results.supports_sane_rowcount()})")
        self.log.debug(f"Execution time ={sql_timer.seconds_elapsed_formatted}")

    def _sql_update_from(
            self,
            update_name: str,
            source_sql: str,
            list_of_joins: List[Tuple[str, Tuple[str, str]]],
            list_of_sets: List[Tuple[str, str]],
            target_alias_in_source: str = 'tgt',
            extra_where: Optional[str] = None,
            connection_name: str = 'sql_upsert',
    ):
        # Multiple Table Updates
        # https://docs.sqlalchemy.org/en/14/core/tutorial.html#multiple-table-updates
        conn = self.connection(connection_name)
        database_type = self.database.dialect_name
        if extra_where is not None:
            and_extra_where = f"AND {extra_where}"
        else:
            extra_where = ''
            and_extra_where = ''
        sql = None
        if database_type in {'oracle'}:
            sql = f"""
            MERGE INTO {self.qualified_table_name} tgt
            USING (
                SELECT
                    {Table._sql_indent_join([f"{join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 24, suffix=',')},
                    {Table._sql_indent_join([f"{set_entry[1]} as col_{set_num}" for set_num, set_entry in enumerate(list_of_sets)], 16, suffix=',')}
                FROM {source_sql}
            ) src
            ON (
                {Table._sql_indent_join([f"tgt.{join_entry[0]} = src.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
            )
            WHEN MATCHED THEN UPDATE
            SET
                {Table._sql_indent_join([f"tgt.{set_entry[0]} = src.col_{set_num}" for set_num, set_entry in enumerate(list_of_sets)], 16, suffix=',')}
            """
        elif database_type in {'mysql'}:
            sql = f"""
            UPDATE {self.qualified_table_name}
                INNER JOIN {source_sql}
                    ON {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {'.'.join(join_entry[1])}" for join_entry in list_of_joins], 19, prefix='AND ')}
            SET
                {Table._sql_indent_join([f"{target_alias_in_source}.{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
            """
        # SQL Server
        elif database_type in {'mssql'}:
            sql = f"""
            UPDATE {self.qualified_table_name}
            SET
                {Table._sql_indent_join([f"{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
            FROM {self.qualified_table_name}
                INNER JOIN {source_sql}
                    ON {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {'.'.join(join_entry[1])}" for join_entry in list_of_joins], 24, prefix='AND ')}
            {and_extra_where}
            """
        elif database_type == 'sqlite':
            import sqlite3
            from packaging import version
            if version.parse(sqlite3.sqlite_version) > version.parse('3.33.0'):
                # Use PostgreSQL below
                pass
            else:
                # Older
                set_statement_list = []
                where_stmt = f"""
                        {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
                        {and_extra_where}
                """
                for set_entry in list_of_sets:
                    select_part = f"""
                        SELECT {set_entry[1]}
                        FROM {source_sql}
                        WHERE {where_stmt}
                    """
                    set_statement_list.append(f"{set_entry[0]} = ({select_part})")
                set_delimiter = ',\n'
                sql = f"""
                    UPDATE {self.qualified_table_name}
                    SET {set_delimiter.join(set_statement_list)}
                    WHERE EXISTS ({select_part})
                """
        if sql is None:
            # SQLite 3.33+ and PostgreSQL
            # https://sqlite.org/lang_update.html
            sql = f"""
            UPDATE {self.qualified_table_name}
            SET
                {Table._sql_indent_join([f"{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
            FROM {source_sql}
            WHERE
                {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
                {and_extra_where}
            """

        sql_timer = Timer()
        self.log.debug(f"Running {update_name}")
        self.log.debug(sql)
        sql = sqlalchemy.text(sql)
        results = self.execute(sql, connection_name=connection_name)
        self.log.debug(f"Impacted rows = {results.rowcount:,} (meaningful count? {results.supports_sane_rowcount()})")
        self.log.debug(f"Execution time ={sql_timer.seconds_elapsed_formatted}")