"""
Created on Sep 17, 2014
@author: Derek Wood
"""
# https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations
import codecs
import contextlib
import dataclasses
import math
import sys
import textwrap
import traceback
import warnings
from datetime import datetime, date, time, timedelta
from decimal import Decimal
from decimal import InvalidOperation
from typing import *
from typing import Iterable, Callable, List, Union
import sqlalchemy
from gevent import spawn, sleep
from gevent.queue import Queue
from sqlalchemy import CHAR
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.sql.type_api import TypeEngine
import bi_etl
from bi_etl.bulk_loaders.bulk_loader import BulkLoader
from bi_etl.components.etlcomponent import ETLComponent
from bi_etl.components.get_next_key.local_table_memory import LocalTableMemory
from bi_etl.components.readonlytable import ReadOnlyTable
from bi_etl.components.row.column_difference import ColumnDifference
from bi_etl.components.row.row import Row
from bi_etl.components.row.row_status import RowStatus
from bi_etl.conversions import nvl, int2base
from bi_etl.conversions import replace_tilda
from bi_etl.conversions import round_datetime_ms
from bi_etl.conversions import str2date
from bi_etl.conversions import str2datetime
from bi_etl.conversions import str2decimal
from bi_etl.conversions import str2float
from bi_etl.conversions import str2int
from bi_etl.conversions import str2time
from bi_etl.database import DatabaseMetadata
from bi_etl.exceptions import NoResultFound
from bi_etl.lookups.lookup import Lookup
from bi_etl.scheduler.task import ETLTask
from bi_etl.statement_queue import StatementQueue
from bi_etl.statistics import Statistics
from bi_etl.timer import Timer
from bi_etl.utility import dict_to_str
from bi_etl.utility import get_integer_places
@dataclasses.dataclass
class PendingUpdate:
update_where_keys: Tuple[str]
update_where_values: Iterable[Any]
update_row_new_values: Row
# noinspection SqlDialectInspection
class Table(ReadOnlyTable):
"""
A class for accessing and updating a table.
Parameters
----------
task : ETLTask
The instance to register in (if not None)
database : bi_etl.database.DatabaseMetadata
The database to find the table/view in.
table_name : str
The name of the table/view.
exclude_columns
Optional. A list of columns to exclude from the table/view. These columns will not be included in SELECT, INSERT, or UPDATE statements.
Attributes
----------
auto_generate_key: boolean
Should the primary key be automatically generated by the insert/upsert process?
If True, the process will get the current maximum value and then increment it with each insert.
batch_size: int
How many rows should be inserted / updated / deleted in a single batch.
Default = 5000. Assigning None will use the default.
delete_flag : str
The name of the delete_flag column, if any.
(inherited from ReadOnlyTable)
delete_flag_yes : str, optional
The value of delete_flag for deleted rows.
(inherited from ReadOnlyTable)
delete_flag_no : str, optional
The value of delete_flag for *not* deleted rows.
(inherited from ReadOnlyTable)
default_date_format: str
The date parsing format to use for str -> date conversions.
If more than one date format exists in the source, then explicit conversions will be required.
Default = '%Y-%m-%d'
default_date_time_format: Union[str, Iterable[str]]
The date+time parsing format to use for str -> date time conversions.
If more than one date format exists in the source, then explicit conversions will be required.
Default = ('%Y-%m-%d %H:%M:%S', '%Y-%m-%d', '%m/%d/%Y %H:%M:%S', '%m/%d/%Y')
default_time_format: str
The time parsing format to use for str -> time conversions.
If more than one date format exists in the source, then explicit conversions will be required.
Default = '%H:%M:%S'
force_ascii: boolean
Should text values be forced into the ascii character set before passing to the database?
Default = False
last_update_date: str
Name of the column which we should update when table updates are made.
Default = None
log_first_row : boolean
Should we log progress on the first row read. *Only applies if used as a source.*
(inherited from ETLComponent)
max_rows : int, optional
The maximum number of rows to read. *Only applies if Table is used as a source.*
(inherited from ETLComponent)
primary_key
The name of the primary key column(s). Only impacts trace messages. Default=None.
If not passed in, will use the database value, if any.
(inherited from ETLComponent)
progress_frequency : int, optional
How often (in seconds) to output progress messages.
(inherited from ETLComponent)
progress_message : str, optional
The progress message to print.
Default is ``"{logical_name} row # {row_number}"``. Note ``logical_name`` and ``row_number``
substitutions applied via :func:`format`.
(inherited from ETLComponent)
special_values_descriptive_columns
A list of columns that get longer descriptive text in :meth:`get_missing_row`,
:meth:`get_invalid_row`, :meth:`get_not_applicable_row`, :meth:`get_various_row`
(inherited from ReadOnlyTable)
track_source_rows: boolean
Should the :meth:`upsert` method keep a set container of source row keys that it has processed?
That set would then be used by :meth:`update_not_processed`, :meth:`logically_delete_not_processed`,
and :meth:`delete_not_processed`.
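Examples
--------
A minimal usage sketch; ``example_db`` (a DatabaseMetadata instance),
``source_component``, and the ``d_example`` table name are assumptions,
not part of this module::

    target_table = Table(task, example_db, 'd_example')
    target_table.auto_generate_key = True
    target_table.batch_size = 1000
    for source_row in source_component:
        target_table.upsert(source_row)
    target_table.commit()
    target_table.close()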
"""
DEFAULT_BATCH_SIZE = 5000
# Replacement for float Not a Number (NaN) values
NAN_REPLACEMENT_VALUE = None
from enum import IntEnum, unique
@unique
class InsertMethod(IntEnum):
execute_many = 1
insert_values_list = 2
bulk_load = 3
@unique
class UpdateMethod(IntEnum):
execute_many = 1
bulk_load = 3
@unique
class DeleteMethod(IntEnum):
execute_many = 1
bulk_load = 3
def __init__(
self,
task: Optional[ETLTask],
database: DatabaseMetadata,
table_name: str,
table_name_case_sensitive: bool = True,
schema: Optional[str] = None,
exclude_columns: Optional[set] = None,
**kwargs
):
# Don't pass kwargs up. They should be set here at the end
super(Table, self).__init__(
task=task,
database=database,
table_name=table_name,
table_name_case_sensitive=table_name_case_sensitive,
schema=schema,
exclude_columns=exclude_columns,
)
self.support_multiprocessing = False
self._special_row_header = None
self._insert_method = Table.InsertMethod.execute_many
self._update_method = Table.UpdateMethod.execute_many
self._delete_method = Table.DeleteMethod.execute_many
self.bulk_loader = None
self._bulk_rollback = False
self._bulk_load_performed = False
self._bulk_iter_sentinel = None
self._bulk_iter_queue = None
self._bulk_iter_worker = None
self._bulk_iter_max_q_length = 0
self._bulk_defaulted_columns = set()
self.track_update_columns = True
self.track_source_rows = False
self.auto_generate_key = False
self.upsert_called = False
self.last_update_date = None
self.use_utc_times = False
self.default_date_format = '%Y-%m-%d'
self.default_date_time_format = (
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d',
'%m/%d/%Y %H:%M:%S',
'%m/%d/%Y',
)
self.default_time_format = '%H:%M:%S'
self.default_long_text = 'Not Available'
self.default_medium_text = 'N/A'
self.default_char1_text = '?'
self.force_ascii = False
codecs.register_error('replace_tilda', replace_tilda)
self.__batch_size = self.DEFAULT_BATCH_SIZE
self.__transaction_pool = dict()
self.skip_coercion_on = {}
self._logical_delete_update = None
# Safe type mode is slower, but gives better error messages than the database,
# which would likely give a not-so-helpful message or silently truncate the value.
self.safe_type_mode = True
# Init table "memory"
self.max_keys = dict()
self.max_locally_allocated_keys = dict()
self._table_key_memory = LocalTableMemory(self)
self._max_keys_lock = contextlib.nullcontext()
# Make a self example source row
self._example_row = self.Row()
if self.columns is not None:
for c in self.columns:
self._example_row[c] = None
self.source_keys_processed = set()
self._bind_name_map = None
self.insert_hint = None
# A list of any pending rows to be inserted
self.pending_insert_stats = None
self.pending_insert_rows = list()
# A StatementQueue of any pending delete statements, the queue has the
# statement itself (based on the keys), and a list of pending values
self.pending_delete_stats = None
self.pending_delete_statements = StatementQueue()
# A list of any pending rows to apply as updates
self.pending_update_stats = None
self.pending_update_rows: Dict[tuple, PendingUpdate] = dict()
self._coerce_methods_built = False
# Should be done after all other self attributes are created
self.set_kwattrs(**kwargs)
def close(self, error: bool = False):
if not self.is_closed:
if self.in_bulk_mode:
if self._bulk_load_performed:
self.log.debug(
f"{self}.close() NOT causing bulk load since a bulk load has already been performed."
)
elif error:
self.log.info(f"{self}.close() NOT causing bulk load due to error.")
else:
self.bulk_load_from_cache()
else:
self._insert_pending_batch()
# Call self.commit() if any connection has an active transaction,
# but leave it up to commit() to determine the order
any_active_transactions = False
for connection_name in self._connections_used:
if self.database.has_active_transaction(connection_name):
any_active_transactions = True
if any_active_transactions:
self.commit()
super(Table, self).close(error=error)
self.clear_cache()
def __iter__(self) -> Iterable[Row]:
# Note: yield_per will break if the transaction is committed while we are looping
for row in self.where(None):
yield row
@property
def batch_size(self):
return self.__batch_size
@batch_size.setter
def batch_size(self, batch_size):
if batch_size is not None:
if not self.in_bulk_mode:
if batch_size > 0:
self.__batch_size = batch_size
else:
self.__batch_size = 1
@property
def table_key_memory(self):
return self._table_key_memory
@table_key_memory.setter
def table_key_memory(self, table_key_memory):
self._table_key_memory = table_key_memory
@property
def in_bulk_mode(self):
return self._insert_method == Table.InsertMethod.bulk_load
def set_bulk_loader(
self,
bulk_loader: BulkLoader
):
self.log.info(f'Changing {self} to bulk load method')
self.__batch_size = sys.maxsize
self._insert_method = Table.InsertMethod.bulk_load
self._update_method = Table.UpdateMethod.bulk_load
self._delete_method = Table.DeleteMethod.bulk_load
self.bulk_loader = bulk_loader
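# Example sketch of switching a Table to bulk-load mode.  ``my_bulk_loader`` stands
# in for any concrete BulkLoader subclass instance (hypothetical name, not defined here):
#
#     target_table = Table(task, database, 'target_table_name')
#     target_table.set_bulk_loader(my_bulk_loader)
#     # From here on, inserts / updates / deletes are routed through the bulk loader
#     # (insert_method, update_method, and delete_method all report bulk_load).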
def cache_row(
self,
row: Row,
allow_update: bool = False,
allow_insert: bool = True,
):
if self.in_bulk_mode:
self._bulk_load_performed = False
if self.bulk_loader.needs_all_columns:
if row.column_set != self.column_names_set:
# Add missing columns setting to default value
for column_name in self.column_names_set - row.column_set:
column = self.get_column(column_name)
default = column.default
if not column.nullable and default is None:
if column.type.python_type == str:
col_len = column.type.length
if col_len is None:
col_len = 4000
elif col_len > 4 and self.database.uses_bytes_length_limits:
col_len = int(col_len / 4)
if col_len >= len(self.default_long_text):
default = self.default_long_text
elif col_len >= len(self.default_medium_text):
default = self.default_medium_text
else:
default = self.default_char1_text
row.set_keeping_parent(column_name, default)
if column_name not in self._bulk_defaulted_columns:
self.log.warning(f'defaulted column {column_name} to {default}')
self._bulk_defaulted_columns.add(column_name)
super().cache_row(
row,
allow_update=allow_update,
allow_insert=allow_insert
)
@property
def insert_method(self):
return self._insert_method
@insert_method.setter
def insert_method(self, value):
if value == Table.InsertMethod.bulk_load:
raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
self._insert_method = value
@property
def update_method(self):
return self._update_method
@update_method.setter
def update_method(self, value):
if value == Table.UpdateMethod.bulk_load:
raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
self._update_method = value
@property
def delete_method(self):
return self._delete_method
@delete_method.setter
def delete_method(self, value):
if value == Table.DeleteMethod.bulk_load:
raise ValueError('Do not manually set bulk mode property. Use set_bulk_loader instead.')
self._delete_method = value
def autogenerate_sequence(
self,
row: Row,
seq_column: str,
force_override: bool = True,
):
# Make sure we have a column object
seq_column_obj = self.get_column(seq_column)
# If key value is not already set, or we are supposed to force override
if row.get(seq_column_obj.name) is None or force_override:
next_key = self.table_key_memory.get_next_key(seq_column)
row.set_keeping_parent(seq_column_obj.name, next_key)
return next_key
def autogenerate_key(
self,
row: Row,
force_override: bool = True,
):
if self.auto_generate_key:
if self.primary_key is None:
raise ValueError("No primary key set")
pk_list = list(self.primary_key)
if len(pk_list) > 1:
raise ValueError(
f"Can't auto generate a compound key with table {self} pk={self.primary_key}"
)
key = pk_list[0]
return self.autogenerate_sequence(row, seq_column=key, force_override=force_override)
# noinspection PyUnresolvedReferences
def _trace_data_type(
self,
target_name: str,
t_type: TypeEngine,
target_column_value: object,
):
try:
self.log.debug(f"{target_name} t_type={t_type}")
self.log.debug(f"{target_name} t_type.precision={t_type.precision}")
self.log.debug(f"{target_name} target_column_value={target_column_value}")
self.log.debug(
f"{target_name} get_integer_places(target_column_value)={get_integer_places(target_column_value)}"
)
self.log.debug(
f"{target_name} (t_type.precision - t_type.scale)={(nvl(t_type.precision, 0) - nvl(t_type.scale, 0))}"
)
except AttributeError as e:
self.log.error(traceback.format_exc())
self.log.debug(repr(e))
def _generate_null_check(
self,
target_column_object: ColumnElement
) -> str:
target_name = target_column_object.name
code = ''
if not target_column_object.nullable:
if not (self.auto_generate_key and target_name in self.primary_key):
# Check for nulls. Not as an ELSE because the conversion logic might have made the value null
code = textwrap.dedent(
f"""\
# base indent
if target_column_value is None:
msg = "{self}.{target_name} has is not nullable and this cannot accept value '{{val}}'".format(
val=target_column_value,
)
raise ValueError(msg)
"""
)
return code
def _get_coerce_method_name_by_str(self, target_column_name: str) -> str:
if not self._coerce_methods_built:
self._build_coerce_methods()
return f"_coerce_{target_column_name}"
def _get_coerce_method_name_by_object(
self,
target_column_object: Union[str, 'sqlalchemy.sql.expression.ColumnElement']
) -> str:
target_column_name = self.get_column_name(target_column_object)
return self._get_coerce_method_name_by_str(target_column_name)
def get_coerce_method(
self,
target_column_object: Union[str, 'sqlalchemy.sql.expression.ColumnElement']
) -> Callable:
if not self._coerce_methods_built:
self._build_coerce_methods()
method_name = self._get_coerce_method_name_by_object(target_column_object)
try:
return getattr(self, method_name)
except AttributeError:
raise AttributeError(
f'{self} does not have a coerce method for {target_column_object} '
f'- check that vs columns list {self.column_names}'
)
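# Example sketch of calling a coerce method directly; the 'amount' column name is
# hypothetical:
#
#     coerce_amount = target_table.get_coerce_method('amount')
#     clean_value = coerce_amount('1,234.56')
#
# The returned bound method applies the conversion plus the length / precision / null
# checks generated by _build_coerce_methods, raising ValueError if the value cannot
# be stored in the column.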
def _make_generic_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_str_coerce(
self,
target_column_object: ColumnElement,
):
target_name = target_column_object.name
t_type = target_column_object.type
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += textwrap.dedent(
"""\
# base indent
if isinstance(target_column_value, str):
"""
)
if self.force_ascii:
# Passing ascii bytes to cx_Oracle is not working.
# We need to pass a str value.
# So we'll use encode with 'replace' to force ascii compatibility
code += textwrap.dedent(
"""\
# base indent
target_column_value = \
target_column_value.encode('ascii', 'replace_tilda').decode('ascii')
"""
)
else:
code += textwrap.dedent(
"""\
# base indent
pass
"""
)
code += textwrap.dedent(
"""\
# base indent
elif isinstance(target_column_value, bytes):
target_column_value = target_column_value.decode('ascii')
elif target_column_value is None:
return None
else:
target_column_value = str(target_column_value)
"""
)
# Note: t_type.length is None for CLOB fields
if t_type.length is not None:
try:
if t_type.length > 0:
if self.database.uses_bytes_length_limits:
# Encode the str as utf-8 to get the byte length
code += textwrap.dedent(
f"""\
# base indent
value_len = len(target_column_value.encode('utf-8'))
if value_len > {t_type.length}:
msg = ("{self}.{target_name} has type {str(t_type).replace('"', "'")} "
f"which cannot accept value '{{target_column_value}}' because "
f"byte length of {{len(target_column_value.encode('utf-8'))}} > {t_type.length} limit (_make_str_coerce)"
f"Note: char length is {{len(target_column_value)}}"
)
raise ValueError(msg)
"""
)
else:
code += textwrap.dedent(
f"""\
# base indent
value_len = len(target_column_value)
if value_len > {t_type.length}:
msg = ("{self}.{target_name} has type {str(t_type).replace('"', "'")} "
f"which cannot accept value '{{target_column_value}}' because "
f"char length of {{len(target_column_value)}} > {t_type.length} limit (_make_str_coerce)"
)
raise ValueError(msg)
"""
)
except TypeError:
# t_type.length is not a comparable type
pass
if isinstance(t_type, CHAR):
code += textwrap.dedent(
f"""\
# base indent
if value_len < {t_type.length}:
target_column_value += ' ' * ({t_type.length} - len(target_column_value))
"""
)
code += str(self._generate_null_check(target_column_object))
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_bytes_coerce(
self,
target_column_object: ColumnElement,
):
target_name = target_column_object.name
t_type = target_column_object.type
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += textwrap.dedent(
"""\
# base indent
if isinstance(target_column_value, bytes):
pass
elif isinstance(target_column_value, str):
target_column_value = target_column_value.encode('utf-8')
elif target_column_value is None:
return None
else:
target_column_value = str(target_column_value).encode('utf-8')
"""
)
# t_type.length is None for BLOB, LargeBinary fields.
# This check really might not be required since all
# discovered types with python_type == bytes
# have no length.
if t_type.length is not None:
try:
if t_type.length > 0:
code += textwrap.dedent(
"""\
# base indent
if len(target_column_value) > {len}:
msg = "{table}.{column} has type {type} which cannot accept value '{{val}}' because length {{val_len}} > {len} limit (_make_bytes_coerce)"
msg = msg.format(
val=target_column_value,
val_len=len(target_column_value),
)
raise ValueError(msg)
"""
).format(
len=t_type.length,
table=self,
column=target_name,
type=str(t_type).replace('"', "'"),
)
except TypeError:
# t_type.length is not a comparable type
pass
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_int_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, int):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, str):
target_column_value = str2int(target_column_value)
elif isinstance(target_column_value, float):
target_column_value = int(target_column_value)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_int_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_float_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
# Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs.
# The thought is that ETL jobs that need the performance and can guarantee no commas
# can explicitly use float
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, float):
if math.isnan(target_column_value):
target_column_value = self.NAN_REPLACEMENT_VALUE
elif target_column_value is None:
return None
elif isinstance(target_column_value, str):
target_column_value = str2float(target_column_value)
elif isinstance(target_column_value, int):
target_column_value = float(target_column_value)
elif isinstance(target_column_value, Decimal):
if math.isnan(target_column_value):
target_column_value = self.NAN_REPLACEMENT_VALUE
else:
target_column_value = float(target_column_value)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_float_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_decimal_coerce(
self,
target_column_object: ColumnElement,
):
target_name = target_column_object.name
t_type = target_column_object.type
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += textwrap.dedent(
"""\
# base indent
if isinstance(target_column_value, Decimal):
pass
elif isinstance(target_column_value, float):
pass
elif target_column_value is None:
return None
"""
)
# If for performance reasons you don't want this conversion...
# DON'T send in a string!
# str2decimal takes 765 ns vs 312 ns for Decimal() but handles commas and signs.
# The thought is that ETL jobs that need the performance and can
# guarantee no commas can explicitly use float or Decimal
code += textwrap.dedent(
"""\
# base indent
elif isinstance(target_column_value, str):
try:
target_column_value = str2decimal(target_column_value)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_decimal_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
# t_type.precision is None for some numeric types.
# Only check integer digits when a precision is defined.
if t_type.precision is not None:
scale = nvl(t_type.scale, 0)
integer_digits_allowed = t_type.precision - scale
code += textwrap.dedent(
"""\
# base indent
if target_column_value is not None:
digits = get_integer_places(target_column_value)
if digits > {integer_digits_allowed}:
msg = "{table}.{column} can't accept '{{val}}' since it has {{digits}} integer digits (_make_decimal_coerce)"\
"which is > {integer_digits_allowed} by (prec {precision} - scale {scale}) limit".format(
val=target_column_value,
digits=digits,
)
raise ValueError(msg)
"""
).format(
table=self,
column=target_name,
integer_digits_allowed=integer_digits_allowed,
precision=t_type.precision,
scale=scale
)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_date_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
# Note: str2date is used for str values; it parses using self.default_date_format.
code += textwrap.dedent(
"""\
# base indent
try:
# Note datetime check must be 1st because datetime tests as an instance of date
if isinstance(target_column_value, datetime):
target_column_value = date(target_column_value.year,
target_column_value.month,
target_column_value.day)
elif isinstance(target_column_value, date):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, str):
target_column_value = str2date(target_column_value, dt_format=self.default_date_format)
else:
target_column_value = str2date(str(target_column_value), dt_format=self.default_date_format)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_date_coerce) {{fmt}}".format(
val=target_column_value,
e=e,
fmt=self.default_date_format,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_datetime_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, datetime):
pass
elif isinstance(target_column_value, timedelta):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, date):
target_column_value = datetime.combine(target_column_value, time.min)
elif isinstance(target_column_value, str):
target_column_value = str2datetime(target_column_value,
dt_format=self.default_date_time_format)
else:
target_column_value = str2datetime(str(target_column_value),
dt_format=self.default_date_time_format)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_datetime_coerce) {{fmt}}".format(
val=target_column_value,
e=e,
fmt=self.default_date_time_format,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
if self.table.bind.dialect.dialect_description == 'mssql+pyodbc':
# fast_executemany Currently causes this error on datetime update (dimension load)
# [Microsoft][ODBC Driver 17 for SQL Server]Datetime field overflow. Fractional second precision exceeds the scale specified in the parameter binding. (0)
# Also see https://github.com/sqlalchemy/sqlalchemy/issues/4418
# All because SQL Server DATETIME values are limited to 3 digits
# Check for datetime2 and don't do this!
if str(target_column_object.type) == 'DATETIME':
self.log.warning(f"Rounding microseconds on {target_column_object}")
code += textwrap.dedent(
"""
# base indent
target_column_value = round_datetime_ms(target_column_value, 3)
"""
)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
code = compile(code, filename='_make_datetime_coerce', mode='exec')
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_time_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
# Note: str2time is used for str values; it parses using self.default_time_format.
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, time):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, datetime):
target_column_value = time(target_column_value.hour,
target_column_value.minute,
target_column_value.second,
target_column_value.microsecond,
target_column_value.tzinfo,
)
elif isinstance(target_column_value, str):
target_column_value = str2time(target_column_value, dt_format=self.default_time_format)
else:
target_column_value = str2time(str(target_column_value), dt_format=self.default_time_format)
except ValueError as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_time_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_timedelta_coerce(
self,
target_column_object: ColumnElement,
):
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
# Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs.
# The thought is that ETL jobs that need the performance and can guarantee no commas
# can explicitly use float
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, timedelta):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, float):
if math.isnan(target_column_value):
target_column_value = self.NAN_REPLACEMENT_VALUE
else:
target_column_value = timedelta(seconds=target_column_value)
elif isinstance(target_column_value, str):
target_column_value = str2float(target_column_value)
target_column_value = timedelta(seconds=target_column_value)
elif isinstance(target_column_value, int):
target_column_value = timedelta(seconds=target_column_value)
else:
target_column_value = timedelta(seconds=target_column_value)
except (TypeError, ValueError) as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_timedelta_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_column_object.name)
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _make_bool_coerce(self, target_column_object: 'sqlalchemy.sql.expression.ColumnElement'):
target_name = target_column_object.name
name = self._get_coerce_method_name_by_object(target_column_object)
code = f"def {name}(self, target_column_value):"
# Accepts bool values directly, plus true/false, yes/no, y/n strings (case-insensitive).
code += textwrap.dedent(
"""\
# base indent
try:
if isinstance(target_column_value, bool):
pass
elif target_column_value is None:
return None
elif isinstance(target_column_value, str):
target_column_value = target_column_value.lower()
if target_column_value in ['true', 'yes', 'y']:
target_column_value = True
elif target_column_value in ['false', 'no', 'n']:
target_column_value = False
else:
msg = "{table}.{column} unexpected value {{val}} (expected true/false, yes/no, y/n)".format(
val=target_column_value,
)
raise ValueError(msg)
else:
target_column_value = bool(target_column_value)
except (TypeError, ValueError) as e:
msg = "{table}.{column} can't accept '{{val}}' due to {{e}} (_make_bool_coerce)".format(
val=target_column_value,
e=e,
)
raise ValueError(msg)
"""
).format(table=self, column=target_name, )
code += self._generate_null_check(target_column_object)
code += textwrap.dedent(
"""
# base indent
return target_column_value
"""
)
try:
code = compile(code, filename='_make_bool_coerce', mode='exec')
exec(code)
except SyntaxError as e:
self.log.exception(f"{e} from code\n{code}")
# Add the new function as a method in this class
exec(f"self.{name} = {name}.__get__(self)")
def _build_coerce_methods(self):
self._coerce_methods_built = True
for column in self.columns:
if column.name in self.skip_coercion_on:
self._make_generic_coerce(column)
else:
t_type = column.type
try:
if t_type.python_type == str:
self._make_str_coerce(column)
elif t_type.python_type == bytes:
self._make_bytes_coerce(column)
elif t_type.python_type == int:
self._make_int_coerce(column)
elif t_type.python_type == float:
self._make_float_coerce(column)
elif t_type.python_type == Decimal:
self._make_decimal_coerce(column)
elif t_type.python_type == date:
self._make_date_coerce(column)
elif t_type.python_type == datetime:
self._make_datetime_coerce(column)
elif t_type.python_type == time:
self._make_time_coerce(column)
elif t_type.python_type == timedelta:
self._make_timedelta_coerce(column)
elif t_type.python_type == bool:
self._make_bool_coerce(column)
else:
warnings.warn(f'Table.build_row has no handler for {t_type} = {type(t_type)}')
self._make_generic_coerce(column)
except NotImplementedError:
raise ValueError(f"Column {column} has type {t_type} with no python_type")
def column_coerce_type(
self,
target_column_object: ColumnElement,
target_column_value: object,
):
"""
This is the slower, non-code-generated data type conversion routine.
"""
target_name = target_column_object.name
if target_column_value is not None:
t_type = target_column_object.type
type_error = False
err_msg = None
try:
if t_type.python_type == str:
if isinstance(target_column_value, str):
if self.force_ascii:
# Passing ascii bytes to cx_Oracle is not working.
# We need to pass a str value.
# So we'll use encode with 'replace' to force ascii compatibility
target_column_value = \
target_column_value.encode('ascii', 'replace_tilda').decode('ascii')
elif isinstance(target_column_value, bytes):
target_column_value = target_column_value.decode('ascii')
else:
target_column_value = str(target_column_value)
# Note: t_type.length is None for CLOB fields
try:
if t_type.length is not None and len(target_column_value) > t_type.length:
type_error = True
err_msg = f"length {len(target_column_value)} > {t_type.length} limit"
except TypeError:
# t_type.length is not a comparable type
pass
elif t_type.python_type == bytes:
if isinstance(target_column_value, str):
target_column_value = target_column_value.encode('utf-8')
elif isinstance(target_column_value, bytes):
pass
# We don't update the value
else:
target_column_value = str(target_column_value).encode('utf-8')
# t_type.length is None for BLOB, LargeBinary fields.
# This really might not be required since all
# discovered types with python_type == bytes:
# have no length
if t_type.length is not None:
if len(target_column_value) > t_type.length:
type_error = True
err_msg = f"{len(target_column_value)} > {t_type.length}"
elif t_type.python_type == int:
if isinstance(target_column_value, str):
# Note: str2int takes 590 ns vs 220 ns for int() but handles commas and signs.
target_column_value = str2int(target_column_value)
elif t_type.python_type == float:
# noinspection PyTypeChecker
if isinstance(target_column_value, str):
# Note: str2float takes 635 ns vs 231 ns for float() but handles commas and signs.
# The thought is that ETL jobs that need the performance and can guarantee no commas
# can explicitly use float
target_column_value = str2float(target_column_value)
elif math.isnan(float(target_column_value)):
target_column_value = self.NAN_REPLACEMENT_VALUE
elif t_type.python_type == Decimal:
if isinstance(target_column_value, str):
# If for performance reasons you don't want this conversion...
# DON'T send in a string!
# str2decimal takes 765 ns vs 312 ns for Decimal() but handles commas and signs.
# The thought is that ETL jobs that need the performance and can
# guarantee no commas can explicitly use float or Decimal
target_column_value = str2decimal(target_column_value)
elif isinstance(target_column_value, float):
if math.isnan(target_column_value):
target_column_value = self.NAN_REPLACEMENT_VALUE
if t_type.precision is not None:
scale = nvl(t_type.scale, 0)
if get_integer_places(target_column_value) > (t_type.precision - scale):
type_error = True
err_msg = (
f"{get_integer_places(target_column_value)} "
f"digits > {(t_type.precision - scale)} = "
f"(prec {t_type.precision} - scale {t_type.scale}) limit"
)
elif t_type.python_type == date:
# If we already have a datetime, make it a date
if isinstance(target_column_value, datetime):
target_column_value = date(
target_column_value.year,
target_column_value.month,
target_column_value.day
)
# If we already have a date
elif isinstance(target_column_value, date):
pass
else:
target_column_value = str2date(
str(target_column_value),
dt_format=self.default_date_format
)
elif t_type.python_type == datetime:
# If we already have a date or datetime value
if isinstance(target_column_value, datetime):
if 'mssql' in self.table.bind.dialect.dialect_description:
# fast_executemany Currently causes this error on datetime update (dimension load)
# [Microsoft][ODBC Driver 17 for SQL Server]Datetime field overflow. Fractional second precision exceeds the scale specified in the parameter binding. (0)
# Also see https://github.com/sqlalchemy/sqlalchemy/issues/4418
# All because SQL Server DATETIME values are limited to 3 digits
# Check for datetime2 and don't do this!
if str(target_column_object.type) == 'DATETIME':
target_column_value = round_datetime_ms(target_column_value, 3)
elif isinstance(target_column_value, str):
target_column_value = str2datetime(
target_column_value,
dt_format=self.default_date_time_format
)
else:
target_column_value = str2datetime(
str(target_column_value),
dt_format=self.default_date_time_format
)
elif t_type.python_type == time:
# If we already have a datetime, make it a time
if isinstance(target_column_value, datetime):
target_column_value = time(
target_column_value.hour,
target_column_value.minute,
target_column_value.second,
target_column_value.microsecond,
target_column_value.tzinfo,
)
# If we already have a date or time value
elif isinstance(target_column_value, time):
pass
else:
target_column_value = str2time(
str(target_column_value),
dt_format=self.default_time_format
)
elif t_type.python_type == timedelta:
# If we already have an interval value
if isinstance(target_column_value, timedelta):
pass
else:
# noinspection PyTypeChecker
target_column_value = timedelta(seconds=float(target_column_value))
elif t_type.python_type == bool:
if isinstance(target_column_value, bool):
pass
elif isinstance(target_column_value, str):
target_column_value = target_column_value.lower()
if target_column_value in ['true', 'yes', 'y']:
target_column_value = True
elif target_column_value in ['false', 'no', 'n']:
target_column_value = False
else:
type_error = True
err_msg = "unexpected value (expected true/false, yes/no, y/n)"
else:
warnings.warn(f'Table.build_row has no handler for {t_type} = {type(t_type)}')
except (ValueError, InvalidOperation) as e:
self.log.error(repr(e))
err_msg = repr(e)
type_error = True
if type_error:
msg = (
f"{self}.{target_name} has type {t_type} "
f"which cannot accept value '{target_column_value}' {err_msg}"
)
raise ValueError(msg)
# Check for nulls. Not as an ELSE because the conversion logic might have made the value null
if target_column_value is None:
if not target_column_object.nullable:
if not (self.auto_generate_key and target_name in self.primary_key):
msg = (
f"{self}.{target_name} is not nullable "
f"and cannot accept value '{target_column_value}'"
)
raise ValueError(msg)
# End if target_column_value is not None:
return target_column_value
@staticmethod
def _get_build_row_name(row: Row) -> str:
return f"_build_row_{row.iteration_header.iteration_id}"
def sanity_check_source_mapping(
self,
source_definition: ETLComponent,
source_name: str = None,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
ignore_source_not_in_target: Optional[bool] = None,
ignore_target_not_in_source: Optional[bool] = None,
raise_on_source_not_in_target: Optional[bool] = None,
raise_on_target_not_in_source: Optional[bool] = None,
):
if target_excludes is None:
target_excludes = set()
else:
target_excludes = set(target_excludes)
if self.auto_generate_key:
if self.primary_key is not None:
# Don't complain about the PK column being missing
# Note if PK is more than one column it will fail in autogenerate_key. Doesn't need to be caught here.
key = self.get_column_name(list(self.primary_key)[0])
if key not in source_definition:
target_excludes.add(key)
else:
raise KeyError('Cannot generate key values without a primary key.')
if self.delete_flag is not None and self.delete_flag not in source_definition:
target_excludes.add(self.delete_flag)
if self.last_update_date is not None and self.last_update_date not in source_definition:
target_excludes.add(self.last_update_date)
target_excludes = frozenset(target_excludes)
super().sanity_check_source_mapping(
source_definition=source_definition,
source_name=source_name,
source_excludes=source_excludes,
target_excludes=target_excludes,
ignore_source_not_in_target=ignore_source_not_in_target,
ignore_target_not_in_source=ignore_target_not_in_source,
raise_on_source_not_in_target=raise_on_source_not_in_target,
raise_on_target_not_in_source=raise_on_target_not_in_source,
)
def sanity_check_example_row(
self,
example_source_row,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
ignore_source_not_in_target: Optional[bool] = None,
ignore_target_not_in_source: Optional[bool] = None,
):
self.sanity_check_source_mapping(
example_source_row,
example_source_row.name,
source_excludes=source_excludes,
target_excludes=target_excludes,
ignore_source_not_in_target=ignore_source_not_in_target,
ignore_target_not_in_source=ignore_target_not_in_source,
)
# TODO: Sanity check primary key data types.
# Lookups might fail if the types don't match (although build_row in safe_mode should fix it)
def _insert_stmt(self):
# pylint: disable=no-value-for-parameter
ins = self.table.insert()
if self.insert_hint is not None:
ins = ins.with_hint(self.insert_hint)
return ins
def _insert_pending_batch(
self,
stat_name: str = 'insert',
parent_stats: Optional[Statistics] = None,
):
if self.in_bulk_mode:
raise InvalidOperation('_insert_pending_batch is not allowed for bulk loader')
# Need to delete pending first in case we are doing delete & insert pairs
self._delete_pending_batch(parent_stats=parent_stats)
# Also flush pending updates in case we are doing update & insert pairs
self._update_pending_batch(parent_stats=parent_stats)
if len(self.pending_insert_rows) == 0:
return
if parent_stats is not None:
stats = parent_stats
# Keep track of which parent last used pending inserts
self.pending_insert_stats = stats
else:
stats = self.pending_insert_stats
if stats is None:
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
if self.insert_method == Table.InsertMethod.insert_values_list:
prepare_stats = self.get_stats_entry('prepare ' + stat_name, parent_stats=stats)
prepare_stats.print_start_stop_times = False
prepare_stats.timer.start()
pending_insert_statements = StatementQueue(execute_with_binds=False)
for new_row in self.pending_insert_rows:
if new_row.status != RowStatus.deleted:
stmt_key = new_row.positioned_column_set
stmt = pending_insert_statements.get_statement_by_key(stmt_key)
if stmt is None:
prepare_stats['statements prepared'] += 1
stmt = f"INSERT INTO {self.qualified_table_name} ({', '.join(new_row.columns_in_order)}) VALUES {{values}}"
pending_insert_statements.add_statement(stmt_key, stmt)
prepare_stats['rows prepared'] += 1
stmt_values = new_row.values()
pending_insert_statements.append_values_by_key(stmt_key, stmt_values)
# Change the row status to existing, now any updates should be via update statements instead of in-memory updates
new_row.status = RowStatus.existing
prepare_stats.timer.stop()
# TODO: Slow performance on big varchars -- possible cause:
# By default pyodbc reads all text columns as SQL_C_WCHAR buffers and decodes them using UTF-16LE.
# When writing, it always encodes using UTF-16LE into SQL_C_WCHAR buffers.
# Connection.setencoding and Connection.setdecoding functions can override
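# A possible override sketch (requires `import pyodbc`; applied to the raw DBAPI
# connection, and whether it helps depends on the driver and the data):
#
#     raw_connection = sqlalchemy_connection.connection
#     raw_connection.setdecoding(pyodbc.SQL_CHAR, encoding='utf-8')
#     raw_connection.setencoding(encoding='utf-8')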
try:
db_stats = self.get_stats_entry(stat_name + ' database execute', parent_stats=stats)
db_stats.print_start_stop_times = False
db_stats.timer.start()
rows_affected = pending_insert_statements.execute(self.connection())
db_stats.timer.stop()
db_stats['rows inserted'] += rows_affected
del self.pending_insert_rows
self.pending_insert_rows = list()
except Exception as e:
# self.log.error(traceback.format_exc())
self.log.exception(e)
self.log.error("Bulk insert failed. Applying as single inserts to find error row...")
self.rollback()
self.begin()
# Retry one at a time
try:
pending_insert_statements.execute_singly(self.connection())
# If that didn't cause the error... re raise the original error
self.log.error(
f"Single inserts on {self} did not produce the error. Original error will be issued below."
)
self.rollback()
raise
except Exception as e_single:
self.log.error(f"Single inserts got error {e_single}")
self.log.exception(e_single)
raise
else:
prepare_stats = self.get_stats_entry('prepare ' + stat_name, parent_stats=stats)
prepare_stats.print_start_stop_times = False
prepare_stats.timer.start()
pending_insert_statements = StatementQueue()
for new_row in self.pending_insert_rows:
if new_row.status != RowStatus.deleted:
stmt_key = new_row.positioned_column_set
stmt = pending_insert_statements.get_statement_by_key(stmt_key)
if stmt is None:
prepare_stats['statements prepared'] += 1
stmt = self._insert_stmt()
for c in new_row:
try:
col_obj = self.get_column(c)
stmt = stmt.values({col_obj.name: bindparam(col_obj.name, type_=col_obj.type)})
except KeyError:
self.log.error(f'Extra column found in pending_insert_rows row {new_row}')
raise
# this was causing InterfaceError: (cx_Oracle.InterfaceError) not a query
# stmt = stmt.compile()
pending_insert_statements.add_statement(stmt_key, stmt)
prepare_stats['rows prepared'] += 1
stmt_values = dict()
for c in new_row:
col_obj = self.get_column(c)
stmt_values[col_obj.name] = new_row[c]
pending_insert_statements.append_values_by_key(stmt_key, stmt_values)
# Change the row status to existing, now any updates should be via update statements
new_row.status = RowStatus.existing
prepare_stats.timer.stop()
try:
db_stats = self.get_stats_entry(stat_name + ' database execute', parent_stats=stats)
db_stats.print_start_stop_times = False
db_stats.timer.start()
rows_affected = pending_insert_statements.execute(self.connection())
db_stats.timer.stop()
db_stats['rows inserted'] += rows_affected
del self.pending_insert_rows
self.pending_insert_rows = list()
except SQLAlchemyError as e:
self.log.error(
f"Bulk insert on {self} failed with error {e}. Applying as single inserts to find error row..."
)
self.rollback()
self.begin()
# Retry one at a time
try:
pending_insert_statements.execute_singly(self.connection())
# If that didn't cause the error... re raise the original error
self.log.error(
f"Single inserts on {self} did not produce the error. Original error will be issued below."
)
self.rollback()
raise
except Exception as e_single:
self.log.error(f"Single inserts got error {e_single}")
self.log.exception(e_single)
raise
def insert_row(
self,
source_row: Row,
additional_insert_values: Optional[dict] = None,
maintain_cache: Optional[bool] = None,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
stat_name: str = 'insert',
parent_stats: Optional[Statistics] = None,
) -> Row:
"""
Inserts a row into the database (batching rows as batch_size)
Parameters
----------
maintain_cache
source_row
The row with values to insert
additional_insert_values:
Values to add / override in the row before inserting.
source_excludes:
set of source columns to exclude
target_excludes
set of target columns to exclude
stat_name
parent_stats
Returns
-------
new_row
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
self.begin()
new_row = self.build_row(
source_row=source_row,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=stats,
)
self.autogenerate_key(new_row, force_override=False)
if self.delete_flag is not None:
if self.delete_flag not in new_row or new_row[self.delete_flag] is None:
new_row.set_keeping_parent(self.delete_flag, self.delete_flag_no)
# Set the last update date
if self.last_update_date is not None:
self.set_last_update_date(new_row)
if self.trace_data:
self.log.debug(f"{self} Raw row being inserted:\n{new_row.str_formatted()}")
if self.in_bulk_mode:
if not self.upsert_called:
if self._bulk_iter_sentinel is None:
self.log.info(f"Setting up {self} worker to collect insert_row rows")
# It does make sense to default to not caching inserted rows
# UNLESS this table is also used as a lookup. Lookups should have been defined.
if len(self.lookups) == 0:
self.maintain_cache_during_load = False
self._bulk_iter_sentinel = StopIteration
self._bulk_iter_queue = Queue()
# row_iter = iter(self._bulk_iter_queue.get, self._bulk_iter_sentinel)
row_iter = self._bulk_iter_queue
self._bulk_iter_worker = spawn(
self.bulk_loader.load_from_iterable,
row_iter,
table_object=self,
progress_frequency=None,
)
self._bulk_iter_queue.put(new_row)
self._bulk_iter_max_q_length = max(self._bulk_iter_max_q_length, len(self._bulk_iter_queue))
# Allow worker to act if it is ready
sleep()
else:
if self.batch_size > 1:
new_row.status = RowStatus.insert
self.pending_insert_rows.append(new_row)
if len(self.pending_insert_rows) >= self.batch_size:
self._insert_pending_batch(parent_stats=stats)
else:
# Immediate insert
new_row.status = RowStatus.existing
try:
stmt_values = dict()
for c in new_row:
col_obj = self.get_column(c)
stmt_values[col_obj.name] = new_row[c]
result = self.execute(self._insert_stmt(), stmt_values)
except Exception as e:
self.log.error(f"Error with source_row = {source_row}")
self.log.error(f"Error with inserted row = {new_row.str_formatted()}")
raise e
stats['rows inserted'] += result.rowcount
result.close()
if maintain_cache is None:
maintain_cache = self.maintain_cache_during_load
if maintain_cache:
self.cache_row(new_row, allow_update=False)
stats.timer.stop()
return new_row
def insert(
self,
source_row: Union[MutableMapping, Iterable[MutableMapping]],
additional_insert_values: Optional[dict] = None,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
Insert a row or list of rows in the table.
Parameters
----------
source_row: :class:`Row` or list thereof
Row(s) to insert
additional_insert_values: dict
Additional values to set on each row.
source_excludes
list of Row source columns to exclude when mapping to this Table.
target_excludes
list of Table columns to exclude when mapping from the source Row(s)
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this steps statistics in.
Default is to place statistics in the ETLTask level statistics.
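Examples
--------
A minimal sketch; the column names are assumptions::

    target_table.insert({'example_id': 1, 'example_name': 'Alpha'})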
"""
if isinstance(source_row, MutableMapping):
if not isinstance(source_row, self.row_object):
warnings.warn('Passing types other than Row to insert is slower.')
source_row = self.row_object(data=source_row, iteration_header=self.empty_iteration_header)
return self.insert_row(
source_row,
additional_insert_values=additional_insert_values,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=parent_stats,
**kwargs
)
else:
for row in source_row:
if not isinstance(row, self.row_object):
warnings.warn('Passing types other than Row to insert is slower.')
row = self.row_object(data=row, iteration_header=self.empty_iteration_header)
self.insert_row(
row,
additional_insert_values=additional_insert_values,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=parent_stats,
**kwargs
)
return None
def _delete_pending_batch(
self,
stat_name: str = 'delete',
parent_stats: Optional[Statistics] = None,
):
if self.in_bulk_mode:
raise InvalidOperation('_delete_pending_batch not allowed in bulk mode')
if self.pending_delete_statements.row_count > 0:
if parent_stats is not None:
stats = parent_stats
# Keep track of which parent last used pending delete
self.pending_delete_stats = stats
else:
stats = self.pending_delete_stats
if stats is None:
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats['rows batch deleted'] += self.pending_delete_statements.execute(self.connection())
def _delete_stmt(self):
# pylint: disable=no-value-for-parameter
return self.table.delete()
def delete(
self,
key_values: Union[Iterable, MutableMapping],
lookup_name: Optional[str] = None,
key_names: Optional[Iterable] = None,
maintain_cache: Optional[bool] = None,
stat_name: str = 'delete',
parent_stats: Optional[Statistics] = None,
):
"""
Delete rows matching key_values.
If caching is enabled, don't use key_names (use lookup_name or default to the PK);
otherwise an expensive scan of the entire cache needs to be performed.
Parameters
----------
key_values: dict (or list)
Container of key values to use to identify a row(s) to delete.
lookup_name: str
Optional lookup name those key values are from.
key_names
Optional list of column names those key values are from.
maintain_cache: boolean
Maintain the cache when deleting. Can be expensive if key_names is used because
a scan of the entire cache needs to be performed.
Defaults to True
stat_name: string
Name of this step for the ETLTask statistics.
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this steps statistics in.
Default is to place statistics in the ETLTask level statistics.
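Examples
--------
A minimal sketch; the ``example_id`` key column name is an assumption::

    target_table.delete(key_values={'example_id': 42})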
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
key_values_dict, lookup_name = self._generate_key_values_dict(key_names, key_values, lookup_name)
# Don't use key_names or key_values anymore, use key_values_dict
del key_names
del key_values
if self.in_bulk_mode:
maintain_cache = True
else:
self.begin()
if maintain_cache is None:
maintain_cache = self.maintain_cache_during_load and self.cache_clean
if not maintain_cache:
self.cache_clean = False
if self.batch_size > 1:
# Batched deletes
# This mode is difficult to buffer. We might have multiple delete statements used, so we have
# to maintain a dictionary of statements by tuple(key_names) and a buffer for each
delete_stmt_key = frozenset(key_values_dict.keys())
stmt = self.pending_delete_statements.get_statement_by_key(delete_stmt_key)
if stmt is None:
stmt = self._delete_stmt()
for key_name, key_value in key_values_dict.items():
key = self.get_column(key_name)
stmt = stmt.where(key == bindparam(key.name, type_=key.type))
stmt = stmt.compile()
self.pending_delete_statements.add_statement(delete_stmt_key, stmt)
self.pending_delete_statements.append_values_by_key(delete_stmt_key, key_values_dict)
if len(self.pending_delete_statements) >= self.batch_size:
self._delete_pending_batch(parent_stats=stats)
else:
# Deletes issued as we get them
stmt = self._delete_stmt()
for key_name, key_value in key_values_dict.items():
key = self.get_column(key_name)
stmt = stmt.where(key == key_value)
del_result = self.execute(stmt)
stats['rows deleted'] += del_result.rowcount
del_result.close()
if maintain_cache:
if lookup_name is not None:
full_row = self.get_by_lookup(lookup_name, key_values_dict, parent_stats=stats)
self.uncache_row(full_row)
else:
key_columns_set = set(key_values_dict.keys())
found_matching_lookup = False
for lookup in self.lookups.values():
if set(lookup.lookup_keys) == key_columns_set:
found_matching_lookup = True
full_row = lookup.find(key_values_dict, stats=stats)
self.uncache_row(full_row)
if not found_matching_lookup:
warnings.warn(
f'{self}.delete called with maintain_cache=True and a set of keys that do not match any lookup. '
f'This will be very slow. Keys used = {key_columns_set}'
)
self.uncache_where(key_names=key_values_dict.keys(), key_values_dict=key_values_dict)
stats.timer.stop()
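# Example usage (an illustrative sketch, not from the original source; the table
# instance target_dim and column customer_id are hypothetical):
#
#     target_dim.delete(key_values={'customer_id': 42})
#     target_dim.commit()  # flush any batched delete statements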
# TODO: Add a lookup parameter and change to allow the set to be made of that lookup's keys instead of the primary keys
[docs]
def delete_not_in_set(
self,
set_of_key_tuples: set,
lookup_name: Optional[str] = None,
criteria_list: Optional[list] = None,
criteria_dict: Optional[dict] = None,
use_cache_as_source: Optional[bool] = True,
stat_name: str = 'delete_not_in_set',
progress_frequency: Optional[int] = None,
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
WARNING: This does physical deletes !! See :meth:`logically_delete_in_set` for logical deletes.
Deletes rows matching criteria that are not in the set_of_key_tuples passed in.
Parameters
----------
set_of_key_tuples
List of tuples comprising the primary key values.
This list represents the rows that should *not* be deleted.
lookup_name: str
The name of the lookup to use to find key tuples.
criteria_list : string or list of strings
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict : dict
Dict keys should be columns, values are set using = or in
use_cache_as_source: bool
Attempt to read existing rows from the cache?
stat_name: string
Name of this step for the ETLTask statistics. Default = 'delete_not_in_set'
progress_frequency: int
How often (in seconds) to output progress messages. Default 10. None for no progress messages.
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
stats['rows read'] = 0
stats['rows deleted'] = 0
stats.timer.start()
deleted_rows = list()
self.begin()
if progress_frequency is None:
progress_frequency = self.progress_frequency
progress_timer = Timer()
# Turn off read progress reports
saved_progress_frequency = self.progress_frequency
self.progress_frequency = None
if lookup_name is None:
lookup = self.get_nk_lookup()
else:
lookup = self.get_lookup(lookup_name)
# Build a static list of the cache content so that we can modify (physically delete) cache rows
# while iterating over it
row_list = list(
self.where(
criteria_list=criteria_list,
criteria_dict=criteria_dict,
use_cache_as_source=use_cache_as_source,
parent_stats=stats
)
)
for row in row_list:
stats['rows read'] += 1
existing_key = lookup.get_hashable_combined_key(row)
if 0 < progress_frequency <= progress_timer.seconds_elapsed:
progress_timer.reset()
self.log.info(
f"delete_not_in_set current row={stats['rows read']} key={existing_key} deletes_done = {stats['rows deleted']}"
)
if existing_key not in set_of_key_tuples:
stats['rows deleted'] += 1
deleted_rows.append(row)
# In bulk mode we just need to remove them from the cache,
# which is done below where we loop over deleted_rows
if not self.in_bulk_mode:
self.delete(key_values=row, lookup_name=self._get_pk_lookup_name(), maintain_cache=False)
for row in deleted_rows:
self.uncache_row(row)
stats.timer.stop()
# Restore saved progress_frequency
self.progress_frequency = saved_progress_frequency
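# Example usage (an illustrative sketch; the names used are hypothetical): physically
# delete any 'EU' rows whose primary key tuple is not in the set we want to keep.
#
#     keys_to_keep = {(1,), (2,), (3,)}
#     target_dim.delete_not_in_set(
#         set_of_key_tuples=keys_to_keep,
#         criteria_dict={'region': 'EU'},
#     )
#     target_dim.commit()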
[docs]
def delete_not_processed(
self,
criteria_list: Optional[list] = None,
criteria_dict: Optional[dict] = None,
use_cache_as_source: bool = True,
allow_delete_all: bool = False,
stat_name: str = 'delete_not_processed',
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
WARNING: This does physical deletes !! See :meth:`logically_delete_not_processed` for logical deletes.
Physically deletes rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`.
Parameters
----------
criteria_list : string or list of strings
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict : dict
Dict keys should be columns, values are set using = or in
use_cache_as_source: bool
Attempt to read existing rows from the cache?
allow_delete_all:
Allow this method to delete ALL rows if no source rows were processed
stat_name: string
Name of this step for the ETLTask statistics.
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
assert self.track_source_rows, "delete_not_processed can't be used if we don't track source rows"
if not allow_delete_all and (self.source_keys_processed is None or len(self.source_keys_processed) == 0):
# We don't want to physically delete all the rows
# But that's only an issue if there are target rows
if any(True for _ in self.where(criteria_list=criteria_list, criteria_dict=criteria_dict, )):
raise RuntimeError(f"{stat_name} called before any source rows were processed on {self}.")
self.delete_not_in_set(
set_of_key_tuples=self.source_keys_processed,
criteria_list=criteria_list,
criteria_dict=criteria_dict,
stat_name=stat_name,
parent_stats=parent_stats
)
self.source_keys_processed = set()
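# Example usage (an illustrative sketch; names are hypothetical): after upserting every
# source row, physically delete target rows never seen in the source. Requires
# track_source_rows to be enabled on the Table.
#
#     for source_row in source_component:
#         target_dim.upsert(source_row)
#     target_dim.delete_not_processed()
#     target_dim.commit()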
[docs]
def logically_delete_not_in_set(
self,
set_of_key_tuples: set,
lookup_name: Optional[str] = None,
criteria_list: Optional[list] = None,
criteria_dict: Optional[dict] = None,
use_cache_as_source: bool = True,
stat_name: str = 'logically_delete_not_in_set',
progress_frequency: Optional[int] = 10,
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
Logically deletes rows matching criteria that are not in the set_of_key_tuples passed in.
Parameters
----------
set_of_key_tuples
List of tuples comprising the primary key values.
This list represents the rows that should *not* be logically deleted.
lookup_name:
Name of the lookup to use
criteria_list :
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict :
Dict keys should be columns, values are set using = or in
use_cache_as_source: bool
Attempt to read existing rows from the cache?
stat_name:
Name of this step for the ETLTask statistics. Default = 'logically_delete_not_in_set'
progress_frequency:
How often (in seconds) to output progress messages. Default = 10.
parent_stats:
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
kwargs:
If called on a HistoryTable (or other child class):
effective_date:
The effective date to use for this operation.
"""
if self._logical_delete_update is None:
iteration_header = self.generate_iteration_header(
logical_name='logically_deleted',
columns_in_order=[self.delete_flag],
)
self._logical_delete_update = self.row_object(iteration_header=iteration_header)
self._logical_delete_update[self.delete_flag] = self.delete_flag_yes
# Do not process rows that are already deleted
if criteria_dict is None:
criteria_dict = {self.delete_flag: self.delete_flag_no}
else:
criteria_dict[self.delete_flag] = self.delete_flag_no
self.update_not_in_set(
updates_to_make=self._logical_delete_update,
set_of_key_tuples=set_of_key_tuples,
lookup_name=lookup_name,
criteria_list=criteria_list,
criteria_dict=criteria_dict,
use_cache_as_source=use_cache_as_source,
progress_frequency=progress_frequency,
stat_name=stat_name,
parent_stats=parent_stats,
**kwargs
)
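# Example usage (an illustrative sketch; names are hypothetical): set delete_flag to
# delete_flag_yes on rows whose key tuple is not in the set passed in.
#
#     current_keys = {(101,), (102,)}
#     target_dim.logically_delete_not_in_set(set_of_key_tuples=current_keys)
#     target_dim.commit()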
[docs]
def logically_delete_not_processed(
self,
criteria_list: Optional[list] = None,
criteria_dict: Optional[dict] = None,
use_cache_as_source: bool = True,
allow_delete_all: bool = False,
stat_name='logically_delete_not_processed',
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
Logically deletes rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`.
Parameters
----------
criteria_list:
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict:
Dict keys should be columns, values are set using = or in
use_cache_as_source: bool
Attempt to read existing rows from the cache?
allow_delete_all:
Allow this method to delete all rows.
Defaults to False to guard against the case where an error prevented the processing of any source rows.
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
stat_name: string
Name of this step for the ETLTask statistics. Default = 'logically_delete_not_processed'
"""
if not self.track_source_rows:
raise ValueError(
"logically_delete_not_processed can't be used if we don't track source rows"
)
if not allow_delete_all and (self.source_keys_processed is None or len(self.source_keys_processed) == 0):
# We don't want to logically delete all the rows
# But that's only an issue if there are target rows
if any(True for _ in self.where(criteria_list=criteria_list, criteria_dict=criteria_dict, )):
raise RuntimeError(f"{stat_name} called before any source rows were processed on {self}.")
self.logically_delete_not_in_set(
set_of_key_tuples=self.source_keys_processed,
criteria_list=criteria_list,
criteria_dict=criteria_dict,
stat_name='logically_delete_not_processed',
parent_stats=parent_stats,
**kwargs
)
self.source_keys_processed = set()
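# Example usage (an illustrative sketch; names are hypothetical): the typical
# end-of-load pattern. Upsert all source rows, then logically delete whatever
# was not processed.
#
#     for source_row in source_component:
#         target_dim.upsert(source_row)
#     target_dim.logically_delete_not_processed()
#     target_dim.commit()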
[docs]
def logically_delete_not_in_source(
self,
source: ReadOnlyTable,
source_criteria_list: Optional[list] = None,
source_criteria_dict: Optional[dict] = None,
target_criteria_list: Optional[list] = None,
target_criteria_dict: Optional[dict] = None,
use_cache_as_source: Optional[bool] = True,
parent_stats: Optional[Statistics] = None
):
"""
Logically deletes rows matching criteria that are not in the source component passed to this method.
The primary use case for this method is when the upsert method is only passed new/changed records and so cannot
build a complete set of source keys in source_keys_processed.
Parameters
----------
source:
The source to read to get the source keys.
source_criteria_list : string or list of strings
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
source_criteria_dict:
Dict keys should be columns, values are set using = or in
target_criteria_list : string or list of strings
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
target_criteria_dict:
Dict keys should be columns, values are set using = or in
use_cache_as_source:
Attempt to read existing rows from the cache?
parent_stats:
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
self.log.info("Processing deletes")
self.log.info("...getting source keys")
set_of_source_keys = set()
for row in source.where(
column_list=source.primary_key,
criteria_list=source_criteria_list,
criteria_dict=source_criteria_dict,
parent_stats=parent_stats,
):
set_of_source_keys.add(source.get_primary_key_value_tuple(row))
self.log.info("...logically_delete_not_in_set of source keys")
self.logically_delete_not_in_set(
set_of_source_keys,
criteria_list=target_criteria_list,
criteria_dict=target_criteria_dict,
use_cache_as_source=use_cache_as_source,
parent_stats=parent_stats,
)
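# Example usage (an illustrative sketch; component and column names are hypothetical):
# when the source feed only delivers new/changed rows, compare keys directly against
# the source table, optionally restricting both sides to the same slice.
#
#     target_dim.logically_delete_not_in_source(
#         source=source_table,
#         source_criteria_dict={'region': 'EU'},
#         target_criteria_dict={'region': 'EU'},
#     )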
[docs]
def get_current_time(self) -> datetime:
if self.use_utc_times:
# Returns the current UTC date and time, as a naive datetime object.
return datetime.utcnow()
else:
return datetime.now()
[docs]
def set_last_update_date(self, row):
last_update_date_col = self.get_column_name(self.last_update_date)
last_update_coerce = self.get_coerce_method(last_update_date_col)
last_update_value = last_update_coerce(self.get_current_time())
row.set_keeping_parent(last_update_date_col, last_update_value)
[docs]
def get_bind_name(self, column_name: str) -> str:
if self._bind_name_map is None:
max_len = self.database.dialect.max_identifier_length
self._bind_name_map = dict()
bind_names = set()
for other_column_name in self.column_names:
bind_name = f'b_{other_column_name}'
if len(bind_name) > max_len:
bind_name = bind_name[:max_len - 2]
unique_name_found = False
next_disamb = 0
while not unique_name_found:
if next_disamb >= 1:
bind_name = bind_name[:max_len - len(str(next_disamb))] + str(next_disamb)
if bind_name in bind_names:
next_disamb += 1
else:
unique_name_found = True
bind_names.add(bind_name)
self._bind_name_map[other_column_name] = bind_name
return self._bind_name_map[column_name]
else:
return self._bind_name_map[column_name]
def _update_pending_batch(
self,
stat_name: str = 'update',
parent_stats: Optional[Statistics] = None
):
"""
Parameters
----------
stat_name: str
Name of this step for the ETLTask statistics. Default = 'update'
parent_stats: Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
if self.in_bulk_mode:
raise InvalidOperation('_update_pending_batch not allowed in bulk mode')
if len(self.pending_update_rows) == 0:
return
assert self.primary_key, "add_pending_update called for table with no primary key"
if parent_stats is not None:
stats = parent_stats
# Keep track of which parent stats last called us
self.pending_update_stats = stats
else:
stats = self.pending_update_stats
if stats is None:
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
self.begin()
pending_update_statements = StatementQueue()
for pending_update_rec in self.pending_update_rows.values():
row_dict = pending_update_rec.update_row_new_values
if row_dict.status not in [RowStatus.deleted, RowStatus.insert]:
update_where_values = pending_update_rec.update_where_values
update_where_keys = pending_update_rec.update_where_keys
stats['update rows sent to db'] += 1
update_stmt_key = (update_where_keys, row_dict.column_set)
update_stmt = pending_update_statements.get_statement_by_key(update_stmt_key)
if update_stmt is None:
update_stmt = self._update_stmt()
for key_column in update_where_keys:
key = self.get_column(key_column)
bind_name = f"_{key.name}"
# Note SQLAlchemy takes care of converting to positional “qmark” bind parameters as needed
update_stmt = update_stmt.where(key == bindparam(bind_name, type_=key.type))
for c in row_dict.columns_in_order:
col_obj = self.get_column(c)
bind_name = col_obj.name
update_stmt = update_stmt.values({c: bindparam(bind_name, type_=col_obj.type)})
update_stmt = update_stmt.compile()
pending_update_statements.add_statement(update_stmt_key, update_stmt)
stmt_values = dict()
for key_column, key_value in zip(update_where_keys, update_where_values):
key = self.get_column(key_column)
bind_name = f"_{key.name}"
stmt_values[bind_name] = key_value
for c in row_dict.columns_in_order:
# Since we know we aren't updating the key, don't send the keys in the values clause
col_obj = self.get_column(c)
bind_name = col_obj.name
stmt_values[bind_name] = row_dict[c]
pending_update_statements.append_values_by_key(update_stmt_key, stmt_values)
row_dict.status = RowStatus.existing
try:
db_stats = self.get_stats_entry('DB Update', parent_stats=stats)
db_stats['update statements sent to db'] += 1
db_stats.timer.start()
rows_applied = pending_update_statements.execute(self.connection())
db_stats.timer.stop()
db_stats['applied rows'] += rows_applied
except Exception as e:
# Retry one at a time
self.log.error(
f"Bulk update on {self} failed with error {e}. Applying as single updates to find error row..."
)
connection = self.connection(connection_name='singly')
for (stmt, row_dict) in pending_update_statements.iter_single_statements():
try:
connection.execute(stmt, row_dict)
except Exception as e:
self.log.error(f"Error with row {dict_to_str(row_dict)}")
raise e
# If that didn't cause the error... re-raise the original error
self.log.error(f"Single updates on {self} did not produce the error. Original error will be issued below.")
self.rollback()
raise e
self.pending_update_rows.clear()
stats.timer.stop()
def _add_pending_update(
self,
update_where_keys: Tuple[str],
update_where_values: Iterable[Any],
update_row_new_values: Row,
parent_stats: Optional[Statistics] = None,
):
assert self.primary_key, "add_pending_update called for table with no primary key"
assert update_row_new_values.status != RowStatus.insert, "add_pending_update called with row that's pending insert"
assert update_row_new_values.status != RowStatus.deleted, "add_pending_update called with row that's pending delete"
key_tuple = tuple(update_where_values)
pending_update_rec = PendingUpdate(
update_where_keys,
update_where_values,
update_row_new_values
)
if key_tuple not in self.pending_update_rows:
self.pending_update_rows[key_tuple] = pending_update_rec
else:
# Apply updates to the existing pending update
existing_update_rec = self.pending_update_rows[key_tuple]
existing_row = existing_update_rec.update_row_new_values
for col, value in update_row_new_values.items():
existing_row[col] = value
if len(self.pending_update_rows) >= self.batch_size:
self._update_pending_batch(parent_stats=parent_stats)
[docs]
def apply_updates(
self,
row: Row,
changes_list: Optional[MutableSequence[ColumnDifference]] = None,
additional_update_values: MutableMapping = None,
add_to_cache: bool = True,
allow_insert=True,
update_apply_entire_row: bool = False,
stat_name: str = 'update',
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
This method should only be called with a row that has already been transformed into the correct datatypes
and column names.
The update values can be in any of the following parameters
- row (also used for PK)
- changes_list
- additional_update_values
Parameters
----------
row:
The row to update with (needs to have at least PK values)
The values are used for the WHERE clause, so they should reflect the row before any updates are applied.
The updates to make can be sent via changes_list or additional_update_values.
changes_list:
A list of ColumnDifference objects to apply to the row
additional_update_values:
A Row or dict of additional values to apply to the main row
add_to_cache:
Should this method update the cache? (Pass False if the caller will update it.)
allow_insert: boolean
Allow this method to insert a new row into the cache
update_apply_entire_row:
Should the update set all non-key columns or just the changed values?
stat_name:
Name of this step for the ETLTask statistics. Default = 'update'
parent_stats:
kwargs:
Not used in this version
"""
assert self.primary_key, "apply_updates called for table with no primary key"
assert row.status != RowStatus.deleted, f"apply_updates called for deleted row {row}"
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
self.pending_update_stats = stats
stats.timer.start()
# Set the last update date
if self.last_update_date is not None:
self.set_last_update_date(row)
# Use row to get PK where values BEFORE we apply extra updates to the row
# This allows the update WHERE and SET to have different values for the PK fields
update_where_keys = self.primary_key_tuple
update_where_values = [row[c] for c in update_where_keys]
if changes_list is not None:
for d_chg in changes_list:
if isinstance(d_chg, ColumnDifference):
row[d_chg.column_name] = d_chg.new_value
else:
# Handle if we are sent a Mapping instead of a list of ColumnDifference for changes_list
# d_chg would then be the Mapping key (column name)
# noinspection PyTypeChecker
row[d_chg] = changes_list[d_chg]
if additional_update_values is not None:
for col_name, value in additional_update_values.items():
row[col_name] = value
if add_to_cache and self.maintain_cache_during_load:
self.cache_row(row, allow_update=True, allow_insert=allow_insert)
if not self.in_bulk_mode:
# Check that the row isn't pending an insert
if row.status != RowStatus.insert:
if update_apply_entire_row or changes_list is None:
row.status = RowStatus.update_whole
self._add_pending_update(
update_where_keys=update_where_keys,
update_where_values=update_where_values,
update_row_new_values=row,
parent_stats=stats
)
else:
# Support partial updates.
# We could save significant network transit times in some cases if we didn't send the whole row
# back as an update. However, we might also not want a different update statement for each
# combination of updated columns. Unfortunately, having code in update_pending_batch scan the
# combinations for the most efficient ways to group the updates would also likely be slow.
update_columns = {d_chg.column_name for d_chg in changes_list}
if additional_update_values is not None:
update_columns.update(additional_update_values.keys())
partial_row = row.subset(keep_only=update_columns)
partial_row.status = RowStatus.update_partial
self._add_pending_update(
update_where_keys=update_where_keys,
update_where_values=update_where_values,
update_row_new_values=partial_row,
parent_stats=stats
)
if stats is not None:
stats['update called'] += 1
else:
# The update to the pending row is all we need to do
if stats is not None:
stats['updated pending insert'] += 1
stats.timer.stop()
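# Example usage (an illustrative sketch; names are hypothetical): apply additional
# column values to a row previously read from this table. The row supplies the
# primary key values used for the WHERE clause.
#
#     # existing_row is a Row previously obtained from this Table (e.g. via a lookup)
#     target_dim.apply_updates(
#         existing_row,
#         additional_update_values={'status_code': 'INACTIVE'},
#     )
#     target_dim.commit()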
def _update_stmt(self) -> UpdateBase:
return self.table.update()
[docs]
def update_where_pk(
self,
updates_to_make: Row,
key_values: Union[Row, dict, list] = None,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
stat_name: str = 'update_where_pk',
parent_stats: Optional[Statistics] = None,
):
"""
Updates the table using the primary key for the where clause.
Parameters
----------
updates_to_make:
Updates to make to the rows matching the criteria
Can also be used to pass the key_values, so you can pass a single
:class:`~bi_etl.components.row.row_case_insensitive.Row` or ``dict``
to the call and have it automatically get the filter values and updates from it.
key_values:
Optional. dict or list of key to apply as criteria
source_excludes:
Optional. set of column names to exclude from the source row
source columns to exclude when mapping to this Table.
target_excludes:
Optional. set of target column names to exclude when mapping from the source to target
:class:`~bi_etl.components.row.row_case_insensitive.Row` (s)
stat_name:
Name of this step for the ETLTask statistics. Default = 'update_where_pk'
parent_stats:
Optional. Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
source_mapped_as_target_row = self.build_row(
source_row=updates_to_make,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=stats,
)
key_names = self.primary_key
key_values_dict, _ = self._generate_key_values_dict(key_names, key_values, other_values_dict=updates_to_make)
# Don't use key_names or key_values anymore, use key_values_dict
del key_names
del key_values
source_mapped_as_target_row.update(key_values_dict)
self.apply_updates(
source_mapped_as_target_row,
allow_insert=False,
parent_stats=stats,
)
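# Example usage (an illustrative sketch; names are hypothetical): the dict supplies
# both the primary key value for the WHERE clause and the new column value.
#
#     target_dim.update_where_pk(
#         updates_to_make={'customer_id': 42, 'status_code': 'INACTIVE'},
#     )
#     target_dim.commit()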
[docs]
def update(
self,
updates_to_make: Union[Row, dict],
key_names: Optional[Iterable] = None,
key_values: Optional[Iterable] = None,
lookup_name: Optional[str] = None,
update_all_rows: bool = False,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
stat_name: str = 'direct update',
parent_stats: Optional[Statistics] = None,
connection_name: Optional[str] = None,
):
"""
Directly performs a database update. Invalidates caching.
THIS METHOD IS SLOW! If you have a full target row, use apply_updates instead.
Parameters
----------
updates_to_make:
Updates to make to the rows matching the criteria
Can also be used to pass the key_values, so you can pass a single
:class:`~bi_etl.components.row.row_case_insensitive.Row` or ``dict`` to the call and have it automatically get the
filter values and updates from it.
key_names:
Optional. List of columns to apply criteria to (see ``key_values``).
Defaults to Primary Key columns.
key_values:
Optional. List of values to apply as criteria (see ``key_names``).
If not provided, and ``update_all_rows`` is False, look in ``updates_to_make`` for values.
lookup_name: str
Name of the lookup to use
update_all_rows:
Optional. Defaults to False. If set to True, ``key_names`` and ``key_values`` are not required.
source_excludes
Optional. list of :class:`~bi_etl.components.row.row_case_insensitive.Row` source columns to exclude when mapping to this Table.
target_excludes
Optional. list of Table columns to exclude when mapping from the source :class:`~bi_etl.components.row.row_case_insensitive.Row` (s)
stat_name: str
Name of this step for the ETLTask statistics. Default = 'direct update'
parent_stats: bi_etl.statistics.Statistics
Optional. Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
connection_name:
Name of the pooled connection to use
Defaults to DEFAULT_CONNECTION_NAME
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
try:
_ = updates_to_make.columns_in_order
except AttributeError:
warnings.warn('Passing types other than Row to update is slower.')
updates_to_make = self.row_object(data=updates_to_make, iteration_header=self.empty_iteration_header)
# Check if we can pass this off to update_where_pk
if key_names is None and not update_all_rows:
self.update_where_pk(
updates_to_make=updates_to_make,
key_values=key_values,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=stats
)
return
source_mapped_as_target_row = self.build_row(
source_row=updates_to_make,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=stats,
)
stmt = self._update_stmt()
if not update_all_rows:
key_values_dict, lookup_name = self._generate_key_values_dict(
key_names,
key_values,
lookup_name=lookup_name,
other_values_dict=updates_to_make
)
# Don't use key_names or key_values anymore, use key_values_dict
del key_names
del key_values
for key_name, key_value in key_values_dict.items():
key = self.get_column(key_name)
stmt = stmt.where(key == key_value)
# We could optionally scan the entire cache and apply updates there.
# Instead, for now, we un-cache the row and set the lookups to fall back to the database
# TODO: If we have a lookup based on an updated value, the original value would still be
# in the lookup. We don't know the original without doing a lookup.
self.cache_clean = False
if self.maintain_cache_during_load:
self.uncache_row(key_values_dict)
else:
assert not key_values, "update_all_rows set and yet we got key_values"
# End not update all rows
# Add set statements to the update
for c in source_mapped_as_target_row:
v = source_mapped_as_target_row[c]
stmt = stmt.values({c: v})
self.begin()
stats['direct_updates_sent_to_db'] += 1
result = self.execute(stmt, connection_name=connection_name)
stats['direct_updates_applied_to_db'] += result.rowcount
result.close()
stats.timer.stop()
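# Example usage (an illustrative sketch; names are hypothetical): a direct database
# update of every row matching a non-PK column. As the docstring warns, this path is
# slow and invalidates caching.
#
#     target_dim.update(
#         updates_to_make={'status_code': 'INACTIVE'},
#         key_names=['region'],
#         key_values=['EU'],
#     )
#     target_dim.commit()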
[docs]
def update_not_in_set(
self,
updates_to_make,
set_of_key_tuples: set,
lookup_name: str = None,
criteria_list: list = None,
criteria_dict: dict = None,
progress_frequency: int = None,
stat_name: str = 'update_not_in_set',
parent_stats: Statistics = None,
**kwargs
):
"""
Applies updates to all rows matching criteria that are not in the set_of_key_tuples passed in.
Parameters
----------
updates_to_make: :class:`~bi_etl.components.row.row_case_insensitive.Row`
:class:`~bi_etl.components.row.row_case_insensitive.Row` or dict of updates to make
set_of_key_tuples
List of tuples comprising the primary key values. This list represents the rows that should *not* be updated.
lookup_name: str
The name of the lookup to use to find key tuples.
criteria_list : string or list of strings
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict:
Dict keys should be columns, values are set using = or in
stat_name: string
Name of this step for the ETLTask statistics. Default = 'update_not_in_set'
progress_frequency: int
How often (in seconds) to output progress messages. Default 10. None for no progress messages.
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
stats['rows read'] = 0
stats['updates count'] = 0
stats.timer.start()
self.begin()
if progress_frequency is None:
progress_frequency = self.progress_frequency
progress_timer = Timer()
# Turn off read progress reports
saved_progress_frequency = self.progress_frequency
self.progress_frequency = None
if lookup_name is None:
lookup = self.get_nk_lookup()
else:
lookup = self.get_lookup(lookup_name)
if criteria_dict is None:
# Default to not processing rows that are already deleted
criteria_dict = {self.delete_flag: self.delete_flag_no}
else:
# Add rule to avoid processing rows that are already deleted
criteria_dict[self.delete_flag] = self.delete_flag_no
# Note, here we select only lookup columns from self
for row in self.where(
column_list=lookup.lookup_keys,
criteria_list=criteria_list,
criteria_dict=criteria_dict,
connection_name='select',
parent_stats=stats
):
if row.status == RowStatus.unknown:
pass
stats['rows read'] += 1
existing_key = lookup.get_hashable_combined_key(row)
if 0 < progress_frequency <= progress_timer.seconds_elapsed:
progress_timer.reset()
self.log.info(
f"update_not_in_set current current row#={stats['rows read']:,} row key={existing_key} updates done so far = {stats['updates count']:,}"
)
if existing_key not in set_of_key_tuples:
stats['updates count'] += 1
try:
# First we need the entire existing row
target_row = lookup.find(row)
except NoResultFound:
raise RuntimeError(
f"keys {row.as_key_value_list} found in database or cache but not found now by get_by_lookup"
)
# Then we can apply the updates to it
self.apply_updates(
row=target_row,
additional_update_values=updates_to_make,
allow_insert=False,
parent_stats=stats,
)
stats.timer.stop()
# Restore saved progress_frequency
self.progress_frequency = saved_progress_frequency
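# Example usage (an illustrative sketch; names are hypothetical, and a delete_flag
# column is assumed to be configured): stamp rows whose key tuple was not seen.
#
#     seen_keys = {(101,), (102,)}
#     target_dim.update_not_in_set(
#         updates_to_make={'status_code': 'STALE'},
#         set_of_key_tuples=seen_keys,
#     )
#     target_dim.commit()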
[docs]
def update_not_processed(
self,
update_row,
lookup_name: Optional[str] = None,
criteria_list: Optional[Iterable] = None,
criteria_dict: Optional[MutableMapping] = None,
use_cache_as_source: bool = True,
stat_name: Optional[str] = 'update_not_processed',
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
Applies update to all rows matching criteria that are not in the Table memory of rows passed to :meth:`upsert`.
Parameters
----------
update_row:
:class:`~bi_etl.components.row.row_case_insensitive.Row` or dict of updates to make
criteria_list :
Each string value will be passed to :meth:`sqlalchemy.sql.expression.Select.where`.
https://goo.gl/JlY9us
criteria_dict:
Dict keys should be columns, values are set using = or in
lookup_name:
Optional lookup name those key values are from.
use_cache_as_source:
Attempt to read existing rows from the cache?
stat_name:
Name of this step for the ETLTask statistics.
parent_stats:
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
kwargs:
IF HistoryTable or child thereof:
effective_date: datetime
The effective date to use for the update
"""
assert self.track_source_rows, "update_not_processed can't be used if we don't track source rows"
if self.source_keys_processed is None or len(self.source_keys_processed) == 0:
# We don't want to update all the rows
# But that's only an issue if there are target rows
if any(
True for _ in self.where(
criteria_list=criteria_list,
criteria_dict=criteria_dict,
)
):
raise RuntimeError(f"{stat_name} called before any source rows were processed.")
self.update_not_in_set(
updates_to_make=update_row,
set_of_key_tuples=self.source_keys_processed,
criteria_list=criteria_list,
criteria_dict=criteria_dict,
use_cache_as_source=use_cache_as_source,
stat_name=stat_name,
parent_stats=parent_stats,
**kwargs
)
self.source_keys_processed = set()
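# Example usage (an illustrative sketch; names are hypothetical, track_source_rows
# must be enabled): after upserting the source rows, stamp every row that was not
# processed.
#
#     for source_row in source_component:
#         target_dim.upsert(source_row)
#     target_dim.update_not_processed({'status_code': 'NOT_IN_SOURCE'})
#     target_dim.commit()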
def _set_upsert_mode(self):
if self._bulk_iter_sentinel is not None:
raise ValueError(
"Upsert called on a table that is using insert_row method. "
"Please manually set upsert_called = True before load operations."
)
if self.in_bulk_mode and not self.upsert_called:
# One time setup
pk_lookup = self.get_pk_lookup()
pk_lookup.cache_enabled = True
self.cache_filled = True
self.cache_clean = True
self.always_fallback_to_db = False
self.upsert_called = True
[docs]
def upsert(
self,
source_row: Union[Row, List[Row]],
lookup_name: Optional[str] = None,
skip_update_check_on: Optional[frozenset] = None,
do_not_update: Optional[Iterable] = None,
additional_update_values: Optional[dict] = None,
additional_insert_values: Optional[dict] = None,
update_callback: Optional[Callable[[list, Row], None]] = None,
insert_callback: Optional[Callable[[Row], None]] = None,
source_excludes: Optional[frozenset] = None,
target_excludes: Optional[frozenset] = None,
stat_name: str = 'upsert',
parent_stats: Statistics = None,
**kwargs
):
"""
Update (if changed) or Insert a row in the table.
This command will look for an existing row in the target table
(using the primary key lookup if no alternate lookup name is provided). If no existing row is found, an insert
will be generated. If an existing row is found, it will be compared to the row passed in. If changes are found,
an update will be generated.
Returns the row found/inserted, with the auto-generated key (if that feature is enabled)
Parameters
----------
source_row
:class:`~bi_etl.components.row.row_case_insensitive.Row` to upsert
lookup_name
The name of the lookup (see :meth:`define_lookup`) to use when searching for an existing row.
skip_update_check_on
List of column names to not compare old vs new for updates.
do_not_update
List of columns to never update.
additional_update_values
Additional updates to apply when updating
additional_insert_values
Additional values to set on each row when inserting.
update_callback:
Function to pass updated rows to. Function should not modify row.
insert_callback:
Function to pass inserted rows to. Function should not modify row.
source_excludes
list of Row source columns to exclude when mapping to this Table.
target_excludes
list of Table columns to exclude when mapping from the source Row(s)
stat_name
Name of this step for the ETLTask statistics. Default = 'upsert'
parent_stats
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.ensure_exists('upsert source row count')
self._set_upsert_mode()
stats.timer.start()
self.begin()
stats['upsert source row count'] += 1
# Check for existing row
source_mapped_as_target_row = self.build_row(
source_row=source_row,
source_excludes=source_excludes,
target_excludes=target_excludes,
parent_stats=stats,
)
if self.delete_flag is not None:
if self.delete_flag not in source_mapped_as_target_row \
or source_mapped_as_target_row[self.delete_flag] is None:
source_mapped_as_target_row.set_keeping_parent(self.delete_flag, self.delete_flag_no)
try:
# We'll default to using the natural key or the primary key based on which columns are present in the row
if lookup_name is None:
lookup_name = self.get_default_lookup(source_row.iteration_header)
if isinstance(lookup_name, Lookup):
lookup_object = lookup_name
else:
lookup_object = self.get_lookup(lookup_name)
if not lookup_object.cache_enabled and self.maintain_cache_during_load and self.batch_size > 1:
raise AssertionError("Caching needs to be turned on if batch mode is on!")
if self.in_bulk_mode and not (lookup_object.cache_enabled and self.maintain_cache_during_load):
raise AssertionError("Caching needs to be turned on if bulk mode is on!")
del lookup_object
existing_row = self.get_by_lookup(lookup_name, source_mapped_as_target_row, parent_stats=stats)
if self.trace_data:
lookup_keys = self.get_lookup(lookup_name).get_list_of_lookup_column_values(source_mapped_as_target_row)
else:
lookup_keys = None
if do_not_update is None:
do_not_update = set()
if skip_update_check_on is None:
skip_update_check_on = set()
# In case cache was built using an incomplete row, set those missing columns to None before compare
cache_missing_columns = self.column_names_set - existing_row.column_set
for missing_column in cache_missing_columns:
existing_row.set_keeping_parent(missing_column, None)
changes_list = existing_row.compare_to(
source_mapped_as_target_row,
exclude=do_not_update | skip_update_check_on,
coerce_types=False,
)
conditional_changes = existing_row.compare_to(
source_mapped_as_target_row,
compare_only=skip_update_check_on,
coerce_types=False,
)
if self.track_update_columns:
col_stats = self.get_stats_entry('updated columns', parent_stats=stats)
for chg in changes_list:
col_stats.add_to_stat(key=chg.column_name, increment=1)
if self.trace_data:
self.log.debug(
f"{lookup_keys} {chg.column_name} changed from {chg.old_value} to {chg.new_value}"
)
if len(changes_list) > 0:
conditional_changes_msg = 'Conditional change applied'
else:
conditional_changes_msg = 'Conditional change NOT applied'
for chg in conditional_changes:
col_stats.add_to_stat(key=chg.column_name, increment=1)
if self.trace_data:
self.log.debug(
f"{lookup_keys} {conditional_changes_msg}: {chg.column_name} changed from {chg.old_value} to {chg.new_value}"
)
if len(changes_list) > 0:
changes_list = list(changes_list) + list(conditional_changes)
self.apply_updates(
row=existing_row,
changes_list=changes_list,
additional_update_values=additional_update_values,
parent_stats=stats
)
if update_callback:
update_callback(changes_list, existing_row)
target_row = existing_row
except NoResultFound:
new_row = source_mapped_as_target_row
if additional_insert_values:
for colName, value in additional_insert_values.items():
new_row[colName] = value
self.autogenerate_key(new_row, force_override=True)
self.insert_row(new_row, parent_stats=stats)
stats['insert new called'] += 1
if insert_callback:
insert_callback(new_row)
target_row = new_row
if self.track_source_rows:
# Keep track of source records, so we can check if target rows don't exist in source
# Note: We use the target_row here since it has already been translated to match the target table
# It's also important that it have the existing surrogate key (if any)
self.source_keys_processed.add(self.get_natural_key_tuple(target_row))
stats.timer.stop()
return target_row
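# Example usage (an illustrative sketch; the component, column, and exclude names are
# hypothetical): the canonical load loop. Each source Row is compared to the existing
# target row (via the default lookup) and inserted or updated as needed.
#
#     for source_row in source_component:
#         target_dim.upsert(
#             source_row,
#             source_excludes=frozenset(['etl_batch_id']),
#         )
#     target_dim.commit()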
[docs]
def upsert_by_pk(
self,
source_row: Row,
stat_name='upsert_by_pk',
parent_stats: Optional[Statistics] = None,
**kwargs
):
"""
Used by :meth:`bi_etl.components.table.Table.upsert_special_values_rows` to find and
update rows by the full PK.
Not expected to be useful outside that use case.
Parameters
----------
source_row: :class:`~bi_etl.components.row.row_case_insensitive.Row`
Row to upsert
stat_name: string
Name of this step for the ETLTask statistics. Default = 'upsert_by_pk'
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
self.begin()
source_row = self.build_row(source_row)
if self.track_source_rows:
# Keep track of source records, so we can check if target rows don't exist in source
self.source_keys_processed.add(self.get_natural_key_tuple(source_row))
try:
existing_row = self.get_by_key(source_row)
changes_list = existing_row.compare_to(source_row, coerce_types=False)
if len(changes_list) > 0:
self.apply_updates(
existing_row,
changes_list=changes_list,
stat_name=stat_name,
parent_stats=stats,
**kwargs
)
except NoResultFound:
self.insert_row(source_row)
stats.timer.stop()
[docs]
def upsert_special_values_rows(
self,
stat_name: str = 'upsert_special_values_rows',
parent_stats: Optional[Statistics] = None,
):
"""
Send all special values rows to upsert to ensure they exist and are current.
Rows come from :meth:`get_missing_row`, :meth:`get_invalid_row`, :meth:`get_not_applicable_row`, :meth:`get_various_row`
Parameters
----------
stat_name: str
Name of this step for the ETLTask statistics. Default = 'upsert_special_values_rows'
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
self.log.info(f"Checking special values rows for {self}")
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats['calls'] += 1
stats.timer.start()
save_auto_gen = self.auto_generate_key
self.auto_generate_key = False
special_rows = [
self.get_missing_row(),
self.get_invalid_row(),
self.get_not_applicable_row(),
self.get_various_row(),
self.get_none_selected_row()
]
for source_row in special_rows:
self.upsert(source_row, parent_stats=stats, lookup_name=self._get_pk_lookup_name())
if not self.in_bulk_mode:
self.commit(parent_stats=stats)
self.auto_generate_key = save_auto_gen
stats.timer.stop()
[docs]
def truncate(
self,
timeout: int = 60,
stat_name: str = 'truncate',
parent_stats: Optional[Statistics] = None,
):
"""
Truncate the table if possible, else delete all.
Parameters
----------
timeout: int
How long in seconds to wait for the truncate. Oracle only.
stat_name: str
Name of this step for the ETLTask statistics. Default = 'truncate'
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
"""
stats = self.get_stats_entry(stat_name, parent_stats=parent_stats)
stats['calls'] += 1
stats.timer.start()
database_type = self.database.dialect_name
truncate_sql = sqlalchemy.text(f'TRUNCATE TABLE {self.quoted_qualified_table_name}')
with self.database.begin():
if database_type == 'oracle':
self.execute(f'alter session set ddl_lock_timeout={timeout}')
self.execute(truncate_sql)
elif database_type in {'mssql', 'mysql', 'postgresql', 'sybase', 'redshift'}:
self.execute(truncate_sql)
else:
self.execute(self._delete_stmt())
stats.timer.stop()
[docs]
def transaction(
self,
connection_name: Optional[str] = None
):
connection_name = self.database.resolve_connection_name(connection_name)
self._connections_used.add(connection_name)
return self.database.begin(connection_name)
[docs]
def begin(
self,
connection_name: Optional[str] = None
):
if not self.in_bulk_mode:
return self.transaction(connection_name)
[docs]
def bulk_load_from_cache(
self,
temp_table: Optional[str] = None,
stat_name: str = 'bulk_load_from_cache',
parent_stats: Optional[Statistics] = None,
):
if self.in_bulk_mode:
if self._bulk_rollback:
self.log.info("Not performing bulk load due to rollback")
else:
if self.bulk_loader is not None:
stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
stats.timer.start()
assert isinstance(
self.bulk_loader, BulkLoader
), f'bulk_loader property needs to be an instance of BulkLoader, not {type(self.bulk_loader)}'
self.close_connections()
if self._bulk_iter_sentinel is None:
if self.upsert_called:
self.log.info(f"Bulk load from lookup cache for {self}")
if temp_table is None:
temp_table = self.qualified_table_name
row_count = self.bulk_loader.load_table_from_cache(self, temp_table)
stats['rows_loaded'] = row_count
else:
self.log.info(
f"Bulk for {self} -- nothing to do. Neither Upsert nor insert_row was called."
)
else:
self.log.info(f"Bulk load from insert_row for {self}")
# Finish the iterator on the queue
self._bulk_iter_queue.put(self._bulk_iter_sentinel)
sleep(1)
self._bulk_iter_max_q_length = max(self._bulk_iter_max_q_length, len(self._bulk_iter_queue))
stats['max_insert_queue'] = self._bulk_iter_max_q_length
stats['ending_insert_queue'] = len(self._bulk_iter_queue)
self.log.info("Waiting for bulk loader")
self._bulk_iter_worker.join()
stats['rows_loaded'] = self._bulk_iter_worker.value
self._bulk_iter_sentinel = None
if self._bulk_iter_worker.exception is not None:
raise self._bulk_iter_worker.exception
stats.timer.stop()
self._bulk_load_performed = True
else:
raise ValueError('bulk_loader not set')
else:
raise ValueError('bulk_load_from_cache called but in_bulk_mode is False')
[docs]
def commit(
self,
stat_name: str = 'commit',
parent_stats: Optional[Statistics] = None,
print_to_log: bool = True,
connection_name: Optional[str] = None,
begin_new: bool = True,
):
"""
Flush any buffered deletes, updates, or inserts
Parameters
----------
stat_name: str
Name of this step for the ETLTask statistics. Default = 'commit'
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
print_to_log: bool
Should this add a debug log entry for each commit. Defaults to true.
connection_name:
Name of the pooled connection to use
Defaults to DEFAULT_CONNECTION_NAME
begin_new:
Start a new transaction after commit
"""
if self.in_bulk_mode:
self.log.debug(
f"{self}.commit() does nothing in bulk mode. "
f"{self}.bulk_load_from_cache() or {self}.close() will load the pending data."
)
# Commit is not final enough to perform the bulk load.
else:
# insert_pending_batch calls other *_pending methods
self._insert_pending_batch()
stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
# Start & stop appropriate timers for each
stats['commit count'] += 1
stats.timer.start()
if connection_name is not None:
self.log.debug(f"{self} commit on specified connection {connection_name}")
self.database.commit(connection_name)
if begin_new:
self.begin(connection_name=connection_name)
else:
# We need to close these in the correct order to prevent deadlocks
commit_order = [
# Read only
'select',
'max',
'get_one',
# Write
'default', # Used for inserts
'sql_upsert',
'singly', # Really should not be open, it gets rolled back
# Any others that exist will be committed last
]
self.log.info(
f"{self} commit on all active connections "
"(if this deadlocks, provide explicit connection names "
"to commit in the correct order)."
)
for possible_connection_name in commit_order:
if possible_connection_name in self._connections_used:
self.database.commit(possible_connection_name)
if begin_new:
self.begin(connection_name=possible_connection_name)
for used_connection_name in self._connections_used:
if used_connection_name not in commit_order:
self.database.commit(used_connection_name)
if begin_new:
self.begin(connection_name=used_connection_name)
self._connections_used = set()
stats.timer.stop()
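# Example usage (an illustrative sketch; the connection name shown is hypothetical):
# flush pending inserts/updates/deletes and commit. If committing all connections
# risks a deadlock, name the connection explicitly as suggested in the log message above.
#
#     target_dim.commit()
#     # or commit only a specific pooled connection:
#     target_dim.commit(connection_name='default')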
[docs]
def rollback(
self,
stat_name: str = 'rollback',
parent_stats: Optional[Statistics] = None,
connection_name: Optional[str] = None,
begin_new: bool = True,
):
"""
Rollback any uncommitted deletes, updates, or inserts.
Parameters
----------
stat_name: str
Name of this step for the ETLTask statistics. Default = 'rollback'
parent_stats: bi_etl.statistics.Statistics
Optional Statistics object to nest this step's statistics in.
Default is to place statistics in the ETLTask level statistics.
connection_name:
Name of the connection to rollback
begin_new:
Start a new transaction after rollback
"""
if self.in_bulk_mode:
self._bulk_rollback = True
else:
self.log.debug("Rolling back transaction")
stats = self.get_unique_stats_entry(stat_name, parent_stats=parent_stats)
stats['calls'] += 1
stats.timer.start()
if connection_name is not None:
self.log.debug(f"{self} rollback on specified connection {connection_name}")
self.database.rollback(connection_name)
self._connections_used.remove(connection_name)
else:
self.log.debug(f"{self} rollback on all active connections {self._connections_used}")
for connection_name in self._connections_used:
self.database.rollback(connection_name)
self._connections_used = set()
if begin_new:
self.begin(connection_name=connection_name)
stats.timer.stop()
@staticmethod
def _sql_column_not_equals(column_name, alias_1='e', alias_2='s'):
return "(" \
f"( {alias_1}.{column_name} != {alias_2}.{column_name} )" \
f" OR ({alias_1}.{column_name} IS NULL AND {alias_2}.{column_name} IS NOT NULL)" \
f" OR ({alias_1}.{column_name} IS NOT NULL AND {alias_2}.{column_name} IS NULL)" \
")"
def _safe_temp_table_name(self, proposed_name):
max_identifier_length = self.database.bind.dialect.max_identifier_length
if len(proposed_name) > max_identifier_length:
alpha_code = int2base(abs(hash(proposed_name)), 36)
proposed_name = f"tmp{alpha_code}"
if len(proposed_name) > max_identifier_length:
proposed_name = proposed_name[:max_identifier_length]
return proposed_name
@staticmethod
def _sql_indent_join(
statement_list: Iterable[str],
indent: int,
prefix: str = '',
suffix: str = '',
add_newline: bool = True,
prefix_if_not_empty: str = '',
) -> str:
if add_newline:
newline = '\n'
else:
newline = ''
statement_list = list(statement_list)
if len(statement_list) == 0:
prefix_if_not_empty = ''
return prefix_if_not_empty + f"{suffix}{newline}{' ' * indent}{prefix}".join(statement_list)
def _sql_primary_key_name(self) -> str:
if len(self.primary_key) != 1:
raise ValueError(f"{self}.upsert_db_exclusive requires single primary_key column. Got {self.primary_key}")
return self.primary_key[0]
def _sql_now(self) -> str:
return "CURRENT_TIMESTAMP"
def _sql_add_timedelta(
self,
column_name: str,
delta: timedelta,
) -> str:
database_type = self.database.dialect_name
if database_type in {'oracle'}:
if delta.seconds != 0 or delta.microseconds != 0:
return f"{column_name} + INTERVAL '{delta.total_seconds()}' SECOND"
elif delta.days != 0:
return f"{column_name} + {delta.days}"
elif database_type in {'postgresql'}:
if delta.seconds != 0 or delta.microseconds != 0:
return f"{column_name} + INTERVAL '{delta.total_seconds()} SECONDS'"
elif delta.days != 0:
return f"{column_name} + INTERVAL '{delta.days} DAY'"
elif database_type in {'mysql'}:
if delta.seconds != 0:
return f"DATE_ADD({column_name}, INTERVAL {delta.total_seconds()} SECOND)"
elif delta.microseconds != 0:
return f"DATE_ADD({column_name}, INTERVAL {delta.microseconds} MICROSECOND)"
elif delta.days != 0:
return f"DATE_ADD({column_name}, INTERVAL {delta.days} DAY)"
elif database_type in {'sqlite'}:
if delta.seconds != 0 or delta.microseconds != 0:
return f"datetime({column_name}, '{delta.total_seconds()} SECONDS')"
elif delta.days != 0:
return f"datetime({column_name}, '{delta.days:+d} DAYS')"
else:
if delta.seconds != 0 or delta.microseconds != 0:
return f"DATEADD(second, {delta.total_seconds()}, {column_name})"
elif delta.days != 0:
return f"DATEADD(day, {delta.days}, {column_name})"
raise ValueError(f"No _sql_add_timedelta defined for DB {database_type} and timedelta {delta}")
def _sql_date_literal(
self,
dt: datetime,
):
database_type = self.database.dialect_name
if database_type in {'oracle'}:
return f"TIMESTAMP '{dt.isoformat(sep=' ')}'"
elif database_type in {'postgresql', 'redshift'}:
return f"'{dt.isoformat()}'::TIMESTAMP"
elif database_type in {'sqlite'}:
return f"datetime('{dt.isoformat()}')"
else:
return f"TIMESTAMP('{dt.isoformat()}')"
def _sql_insert_new(
self,
source_table: ReadOnlyTable,
matching_columns: Iterable,
effective_date: Optional[datetime] = None,
connection_name: str = 'sql_upsert',
):
if not self.auto_generate_key:
raise ValueError(f"_sql_insert_new expects to only be used with auto_generate_key = True")
now_str = self.get_current_time().isoformat()
insert_new_sql = f"""
INSERT INTO {self.qualified_table_name} (
{self._sql_primary_key_name()},
{self.delete_flag},
{self.last_update_date},
{Table._sql_indent_join(matching_columns, 24, suffix=',')}
)
WITH max_srgt AS (SELECT coalesce(max({self._sql_primary_key_name()}),0) as max_srgt FROM {self.qualified_table_name})
SELECT
max_srgt.max_srgt + ROW_NUMBER() OVER (order by 1) as {self._sql_primary_key_name()},
'{self.delete_flag_no}' as {self.delete_flag},
'{now_str}' as {self.last_update_date},
{Table._sql_indent_join(matching_columns, 16, suffix=',')}
FROM max_srgt
CROSS JOIN
{source_table.qualified_table_name} s
WHERE NOT EXISTS(
SELECT 1
FROM {self.qualified_table_name} e
WHERE {Table._sql_indent_join([f"e.{nk_col} = s.{nk_col}" for nk_col in self.natural_key], 18, prefix='AND ')}
)
"""
self.log.debug('=' * 80)
self.log.debug('insert_new_sql')
self.log.debug('=' * 80)
self.log.debug(insert_new_sql)
sql_timer = Timer()
results = self.execute(insert_new_sql, connection_name=connection_name)
self.log.debug(f"Impacted rows = {results.rowcount} (meaningful count? {results.supports_sane_rowcount()})")
self.log.debug(f"Execution time ={sql_timer.seconds_elapsed_formatted}")
def _sql_update_from(
self,
update_name: str,
source_sql: str,
list_of_joins: List[Tuple[str, Tuple[str, str]]],
list_of_sets: List[Tuple[str, str]],
target_alias_in_source: str = 'tgt',
extra_where: Optional[str] = None,
connection_name: str = 'sql_upsert',
):
# Multiple Table Updates
# https://docs.sqlalchemy.org/en/14/core/tutorial.html#multiple-table-updates
conn = self.connection(connection_name)
database_type = self.database.dialect_name
if extra_where is not None:
and_extra_where = f"AND {extra_where}"
else:
extra_where = ''
and_extra_where = ''
sql = None
if database_type in {'oracle'}:
sql = f"""
MERGE INTO {self.qualified_table_name} tgt
USING (
SELECT
{Table._sql_indent_join([f"{join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 24, suffix=',')},
{Table._sql_indent_join([f"{set_entry[1]} as col_{set_num}" for set_num, set_entry in enumerate(list_of_sets)], 16, suffix=',')}
FROM {source_sql}
) src
ON (
{Table._sql_indent_join([f"tgt.{join_entry[0]} = src.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
)
WHEN MATCHED THEN UPDATE
SET {Table._sql_indent_join([f"tgt.{set_entry[0]} = src.col_{set_num}" for set_num, set_entry in enumerate(list_of_sets)], 16, suffix=',')}
"""
elif database_type in {'mysql'}:
sql = f"""
UPDATE {self.qualified_table_name} INNER JOIN {source_sql}
ON {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {'.'.join(join_entry[1])}" for join_entry in list_of_joins], 19, prefix='AND ')}
SET {Table._sql_indent_join([f"{target_alias_in_source}.{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
"""
# SQL Server
elif database_type in {'mssql'}:
sql = f"""
UPDATE {self.qualified_table_name}
SET {Table._sql_indent_join([f"{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
FROM {self.qualified_table_name}
INNER JOIN
{source_sql}
ON {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {'.'.join(join_entry[1])}" for join_entry in list_of_joins], 24, prefix='AND ')}
{and_extra_where}
"""
elif database_type == 'sqlite':
import sqlite3
from packaging import version
if version.parse(sqlite3.sqlite_version) > version.parse('3.33.0'):
# SQLite 3.33+ supports UPDATE ... FROM; use the PostgreSQL-style statement below
pass
else:
# Older SQLite versions: emulate UPDATE ... FROM with correlated subqueries
set_statement_list = []
where_stmt = f"""
{Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
{and_extra_where}
"""
for set_entry in list_of_sets:
select_part = f"""
SELECT {set_entry[1]}
FROM {source_sql}
WHERE {where_stmt}
"""
set_statement_list.append(f"{set_entry[0]} = ({select_part})")
set_delimiter = ',\n'
sql = f"""
UPDATE {self.qualified_table_name} SET {set_delimiter.join(set_statement_list)}
WHERE EXISTS ({select_part})
"""
if sql is None: # SQLite 3.33+ and PostgreSQL
# https://sqlite.org/lang_update.html
sql = f"""
UPDATE {self.qualified_table_name}
SET {Table._sql_indent_join([f"{set_entry[0]} = {set_entry[1]}" for set_entry in list_of_sets], 16, suffix=',')}
FROM {source_sql}
WHERE {Table._sql_indent_join([f"{self.qualified_table_name}.{join_entry[0]} = {join_entry[1][0]}.{join_entry[1][1]}" for join_entry in list_of_joins], 18, prefix='AND ')}
{and_extra_where}
"""
sql_timer = Timer()
self.log.debug(f"Running {update_name}")
self.log.debug(sql)
sql = sqlalchemy.text(sql)
results = self.execute(sql, connection_name=connection_name)
self.log.debug(f"Impacted rows = {results.rowcount:,} (meaningful count? {results.supports_sane_rowcount()})")
self.log.debug(f"Execution time ={sql_timer.seconds_elapsed_formatted}")