Source code for bi_etl.bulk_loaders.redshift_s3_avro_loader

# https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import os.path
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, TYPE_CHECKING

import fastavro

from bi_etl.bulk_loaders.redshift_s3_base import RedShiftS3Base
from bi_etl.bulk_loaders.s3_bulk_load_config import S3_Bulk_Loader_Config

if TYPE_CHECKING:
    from bi_etl.components.table import Table
    from bi_etl.scheduler.task import ETLTask


class RedShiftS3AvroBulk(RedShiftS3Base):
    def __init__(
            self,
            config: S3_Bulk_Loader_Config,
    ):
        super().__init__(
            config=config,
        )
        if self.s3_file_max_rows is None and self.s3_files_to_generate is None:
            self.s3_file_max_rows = 50000

        # Redshift appears to count the schema as a line read
        self.lines_scanned_modifier = -1

        # Use the snappy codec if python-snappy is installed, otherwise fall back to deflate
        try:
            import snappy
            self.codec = 'snappy'
        except ImportError:
            self.codec = 'deflate'
    @property
    def needs_all_columns(self):
        return False
    def get_copy_sql(
            self,
            s3_source_path: str,
            table_to_load: str,
            file_compression: str = '',
            analyze_compression: str = None,
            options: str = '',
    ):
        analyze_compression = analyze_compression or self.analyze_compression
        if analyze_compression:
            options += f' COMPUPDATE {analyze_compression} '
        return f"""\
            COPY {table_to_load}
            FROM 's3://{self.s3_bucket_name}/{s3_source_path}'
            credentials 'aws_access_key_id={self.s3_user_id};aws_secret_access_key={self.s3_password}'
            AVRO 'auto' {file_compression} {options}
            """
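    # For illustration only (not part of the original source): with hypothetical bucket,
    # path, and credential values, the COPY statement built above comes out roughly as:
    #
    #     COPY my_schema.my_table
    #     FROM 's3://my-bucket/avro/prefix'
    #     credentials 'aws_access_key_id=AKIA...;aws_secret_access_key=...'
    #     AVRO 'auto'  COMPUPDATE ON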
    def _write_to_file(
            self,
            temp_dir,
            file_number: int,
            parsed_schema,
            data_chunk: list,
            local_files_list: list,
    ):
        self.log.debug(f"Starting bulk load chunk {file_number}")
        filepath = os.path.join(temp_dir, f'data_{file_number}.avro')
        local_files_list.append(filepath)
        with open(filepath, 'wb') as avro_file:
            # avro_iterators = self.distribute(iterator, writer_pool_size)
            fastavro.writer(avro_file, parsed_schema, data_chunk, codec=self.codec)
        self.log.debug(f"Closed bulk load chunk {file_number}: {filepath}")
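    # For illustration only (not part of the original source): a chunk file written above
    # can be inspected locally with fastavro's reader, e.g.:
    #
    #     with open(filepath, 'rb') as avro_file:
    #         for record in fastavro.reader(avro_file):
    #             print(record)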
    def load_from_iterable(
            self,
            iterable: Iterable,
            table_object: Table,
            table_to_load: str = None,
            perform_rename: bool = False,
            progress_frequency: int = 10,
            analyze_compression: str = None,
            parent_task: Optional[ETLTask] = None,
    ) -> int:
        table_to_load = table_to_load or table_object.qualified_table_name

        with TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
            # Generate an Avro schema from the table columns
            fields = list()
            type_map = {
                str: ['string'],
                bytes: ['bytes'],
                date: ['string'],
                datetime: ['string'],
                time: ['string'],
                timedelta: ['string'],
                bool: ['boolean'],
                int: ['long'],
                float: ['double'],
                Decimal: ['int', 'double'],
            }
            for column in table_object.columns:
                if column.type.python_type in type_map:
                    col_type = type_map[column.type.python_type]
                else:
                    col_type = ['string']
                fields.append({
                    'name': column.name,
                    'type': ["null"] + col_type,
                    'default': None,
                })
            schema = {
                'name': table_to_load,
                'type': 'record',
                'fields': fields,
            }
            self.log.debug(f"schema = {schema}")
            parsed_schema = fastavro.parse_schema(schema)
            self.log.debug(f"parsed_schema = {parsed_schema}")

            file_number = 0
            data_length = 0
            local_files_list = list()

            # Note: Tried using this method to chunk the list, but it didn't work with the
            #       gevent based threading used by insert_row with bulk loader
            # for data_chunk in more_itertools.chunked(iterator, self.s3_file_max_rows):
            data_chunk = list()
            for row in iterable:
                data_chunk.append(row)
                if len(data_chunk) >= self.s3_file_max_rows:
                    data_length += len(data_chunk)
                    file_number += 1
                    self._write_to_file(
                        temp_dir=temp_dir,
                        file_number=file_number,
                        parsed_schema=parsed_schema,
                        data_chunk=data_chunk,
                        local_files_list=local_files_list,
                    )
                    data_chunk = list()

            # Write the final (possibly partial) chunk
            if len(data_chunk) > 0:
                data_length += len(data_chunk)
                file_number += 1
                self._write_to_file(
                    temp_dir=temp_dir,
                    file_number=file_number,
                    parsed_schema=parsed_schema,
                    data_chunk=data_chunk,
                    local_files_list=local_files_list,
                )

            if data_length > 0:
                self.log.debug(f"Loading from files {local_files_list}")
                rows_loaded = self.load_from_files(
                    local_files_list,
                    table_object=table_object,
                    table_to_load=table_to_load,
                    perform_rename=perform_rename,
                    analyze_compression=analyze_compression,
                )
                if rows_loaded != data_length:
                    self.log.error(
                        f"COPY from files should have loaded {data_length:,} rows "
                        f"but it reports {rows_loaded:,} rows loaded"
                    )
            else:
                self.log.info(f"{self} had nothing to do with 0 rows found")
        return data_length
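# A minimal usage sketch (not part of the original source). The config, table, and row
# iterable below are hypothetical stand-ins for objects a real bi_etl job would provide:
#
#     loader = RedShiftS3AvroBulk(config=my_s3_bulk_loader_config)
#     rows_written = loader.load_from_iterable(
#         iterable=source_rows,       # iterable of row dicts keyed by column name
#         table_object=target_table,  # bi_etl.components.table.Table for the target
#         perform_rename=False,
#     )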