# Source code for zcloud.console_scripts.data_puller_client

import asyncio
from enum import Enum
import logging
import os
import time

import aiohttp
import click
from google.cloud import bigquery, storage
import socketio

logger = logging.getLogger(__name__)
# Constants
WEBSOCKET_URL = "wss://your-cloud-run-instance-url"  # placeholder — replace during integration
TIMEOUT = 60  # Timeout in seconds
BQ_TABLE_ID = "structure-standard-table-id"  # BigQuery table joined in process_structure_ready
ENTITY_ID_COLUMN = "bcif_id"  # NOTE(review): identical to BCIF_ID_COLUMN — confirm this is intentional
BCIF_ID_COLUMN = "bcif_id"
AVRO_FILE_PATH_COLUMN = "avro_file_path"  # column holding the gs:// URIs we download
FIND_STRUCTURE_ENDPOINT = "https://your-cloud-run-instance-url/find_structure"

MESSAGE_ENTITY_ID_KEY = "entityIDs"  # wire key for entity IDs in socket/HTTP payloads
MESSAGE_CORRELATION_ID_KEY = "correlationIDs"  # wire key for correlation IDs

def split_uri_to_bucket_and_path(uri):
    """Split a gs://-style URI into its (bucket, object_path) pair.

    The first two slash-separated fields are the scheme and the empty
    authority separator; the third is the bucket, and everything after it
    is the object path ("" when the URI names only a bucket).
    """
    pieces = uri.split("/", 3)
    bucket = pieces[2]
    object_path = pieces[3] if len(pieces) > 3 else ""
    return bucket, object_path
async def download_blob(bucket, file_path, target_dir="."):
    """Download one object from GCS into target_dir, preserving its path.

    Args:
        bucket: GCS bucket name.
        file_path: object path within the bucket.
        target_dir: local directory the object is written under.

    The google-cloud-storage client is synchronous, so the blocking download
    runs in the event loop's default executor — the original `await`ed the
    call directly, which raises TypeError because it is not awaitable. It
    also used download_to_file, which expects an open file object rather
    than a path; download_to_filename takes the path we build here.
    """

    def _download():
        gcs_client = storage.Client()
        blob = gcs_client.bucket(bucket).blob(file_path)
        destination = os.path.join(target_dir, file_path)
        # Object paths may contain directory components; create them locally
        # so the write does not fail with FileNotFoundError.
        os.makedirs(os.path.dirname(destination) or ".", exist_ok=True)
        blob.download_to_filename(destination)

    await asyncio.get_running_loop().run_in_executor(None, _download)
async def join_get_bq_records_from_list(id_list, join_key_temp_table, join_key_bq, bq_table_id, columns_to_select):
    """Fetch rows from a BigQuery table whose join key matches any id in id_list.

    Args:
        id_list: iterable of string IDs to match.
        join_key_temp_table: unused; kept for interface compatibility with the
            original temp-table formulation.
        join_key_bq: column in bq_table_id matched against id_list.
        bq_table_id: fully qualified BigQuery table id.
        columns_to_select: column names to return for each matching row.

    Returns:
        A list of dict rows (column name -> value), one per matching record.
        The original returned None, which broke the caller that iterates the
        result; it also `await`ed the synchronous result() call (TypeError).
    """
    # Pass the id list as a query parameter rather than interpolating it into
    # the SQL: avoids both malformed literals and injection concerns.
    query = f"""
    SELECT {','.join(columns_to_select)}
    FROM `{bq_table_id}`
    WHERE {join_key_bq} IN UNNEST(@id_list)
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ArrayQueryParameter("id_list", "STRING", list(id_list))]
    )
    bq_client = bigquery.Client()
    loop = asyncio.get_running_loop()
    # The BigQuery client is synchronous; run the blocking calls in the
    # default executor so the event loop stays responsive.
    query_job = await loop.run_in_executor(
        None, lambda: bq_client.query(query, job_config=job_config)
    )
    rows = await loop.run_in_executor(None, query_job.result)
    return [dict(row) for row in rows]
# Notification type enum, this should probably come from a centralized schema registry. Hardcoded here as a placeholder, replace during integration
class NotificationType(Enum):
    """Known subtypes of server "notification" messages.

    NOTE: placeholder values — these should come from a centralized schema
    registry; replace during integration.
    """

    CLIENT_ID_ASSIGNMENT = "client_id_assignment"
class MessageType(Enum):
    """Event names exchanged over the socket connection.

    NOTE: placeholder values — these should come from a centralized schema
    registry; replace during integration.
    """

    NOTIFICATION = "notification"
    STATUS = "status"
    STRUCTURE_READY = "structure_ready"
    REQUEST_CLIENT_ID = "request_client_id"
class FakeMessageSchemaHandler:
    """Placeholder message-schema handler.

    The real handler is a factory that selects the correct handler for a
    message type and does not hardcode keys. Attributes here name the wire
    keys that the message processors read out of incoming payloads.
    """

    def __init__(self, event_name, message):
        self._event_name = event_name
        self._message = message
        self.notification_type = NotificationType.CLIENT_ID_ASSIGNMENT
        self.status_type = None
        self.entity_ids = MESSAGE_ENTITY_ID_KEY
        self.correlation_ids = MESSAGE_CORRELATION_ID_KEY
        # Fix: ClientState.process_notification reads handler.client_id and
        # process_status reads handler.details; the original never defined
        # either, so both raised AttributeError at runtime.
        self.client_id = "clientID"  # matches the key on_reconnect emits
        self.details = "details"  # presumed wire key — TODO confirm against real schema
class ClientState:
    """Per-session client state plus the incoming-message dispatch logic.

    Tracks the server-assigned client ID, the queues of pending work and
    incoming socket messages, and the IDs this client is requesting.
    """

    def __init__(self, entity_ids, correlation_ids, target_dir="."):
        """
        Args:
            entity_ids: entity IDs this client requests structures for.
            correlation_ids: correlation IDs paired with the request.
            target_dir: local directory downloads are written into.
        """
        self.client_id = None  # assigned asynchronously by the server
        self.task_queue = asyncio.Queue()
        self.message_queue = asyncio.Queue()  # Queue for incoming messages
        self._start_time = time.time()
        self._target_dir = target_dir  # download dir
        self._entity_ids = entity_ids
        self._correlation_ids = correlation_ids
        self._message_handlers = {
            MessageType.NOTIFICATION: self.process_notification,
            MessageType.STATUS: self.process_status,
            MessageType.STRUCTURE_READY: self.process_structure_ready,
        }

    def _get_handler(self, message_type, message):
        # Placeholder: the real handler factory selects by message type.
        return FakeMessageSchemaHandler(message_type, message)

    async def set_client_id(self, client_id):
        """Record the server-assigned client ID."""
        self.client_id = client_id

    async def get_client_id(self):
        """Poll until the server has assigned a client ID, then return it."""
        while self.client_id is None:
            await asyncio.sleep(0.1)  # Wait for client_id to be set
        return self.client_id

    # message dispatching logic
    async def process_message(self, event_name, message):
        """Dispatch a raw socket event to its handler coroutine.

        Fix: socketio delivers event names as strings, but the dispatch table
        is keyed by MessageType, so the original lookup always missed.
        Both forms are normalized to MessageType before lookup.
        """
        if not isinstance(event_name, MessageType):
            try:
                event_name = MessageType(event_name)
            except ValueError:
                logger.warning(f"No handler found for event name: {event_name}")
                return
        handler = self._message_handlers.get(event_name)
        if handler is None:
            logger.warning(f"No handler found for event name: {event_name}")
            return
        await handler(message)

    async def process_notification(self, message):
        """Handle a notification; only client-ID assignment is acted on.

        Async so the dispatcher can uniformly await every handler (the
        original awaited this sync method, i.e. `await None` -> TypeError).
        """
        logger.info(f"Processing notification: {message}")
        handler = self._get_handler(MessageType.NOTIFICATION, message)
        if handler.notification_type == NotificationType.CLIENT_ID_ASSIGNMENT:
            # The fake handler may not define client_id; fall back to the wire
            # key on_reconnect uses. TODO confirm against real schema handler.
            client_id_key = getattr(handler, "client_id", "clientID")
            # Fix: the coroutine was previously created but never awaited.
            await self.set_client_id(message[client_id_key])
        else:
            logger.warning(f"Unknown notification type: {handler.notification_type}")
            logger.info(f"Message: {message}")

    async def process_status(self, message):
        """Log a status message to the console; nothing else implemented yet."""
        logger.info(f"Processing status: {message}")
        handler = self._get_handler(MessageType.STATUS, message)
        logger.info(f"Status: {handler.status_type}")
        # The fake handler may not define the details key; fall back to a
        # literal "details" key. TODO confirm against real schema handler.
        details = message.get(getattr(handler, "details", "details"))
        if details:
            logger.info(f"Details: {details}")

    async def process_structure_ready(self, message):
        """Resolve entity IDs to avro-file URIs via BigQuery and download them.

        Async because we ping BQ, then check GCS, then download files.
        """
        logger.info(f"Processing structure ready: {message}")
        handler = self._get_handler(MessageType.STRUCTURE_READY, message)
        entity_ids = message.get(handler.entity_ids)
        if not entity_ids:
            # Fix: the original logged this warning but still queried BQ with
            # a None id list; bail out instead.
            logger.warning("No entity IDs found in structure ready message")
            return
        logger.info(f"Entity IDs: {entity_ids}")
        # Join the entity ids against the structure standard table and select
        # the bcif id and avro file path columns.
        columns_to_select = [BCIF_ID_COLUMN, AVRO_FILE_PATH_COLUMN]
        results = await join_get_bq_records_from_list(
            entity_ids, ENTITY_ID_COLUMN, ENTITY_ID_COLUMN, BQ_TABLE_ID, columns_to_select
        )
        logger.info(f"Results: {results}")
        if not results:
            logger.warning("No BigQuery records found for entity IDs")
            return
        # Group file paths by bucket to drop redundant refs to avro files/buckets.
        avro_data_file_dict = {}
        for result in results:
            bucket, file_path = split_uri_to_bucket_and_path(result[AVRO_FILE_PATH_COLUMN])
            avro_data_file_dict.setdefault(bucket, []).append(file_path)
        logger.info(f"Unique buckets: {avro_data_file_dict.keys()}")
        for bucket, file_paths in avro_data_file_dict.items():
            await asyncio.gather(
                *[download_blob(bucket, fp, self._target_dir) for fp in file_paths]
            )

    # TODO: decide how the entity id (and corresponding bcif id) should be
    # mapped to the avro file (for cases where the avro file holds many
    # records and we only want some of them)
    def get_entity_ids(self):
        return self._entity_ids

    def get_correlation_ids(self):
        return self._correlation_ids

    def get_start_time(self):
        return self._start_time
async def on_reconnect(sio, state, initial_connection=False):
    """Request a client ID from the server on (re)connection.

    Args:
        sio: connected socketio.AsyncClient.
        state: ClientState providing the previously assigned client ID.
        initial_connection: True on the very first connect, when no client ID
            exists yet.
    """
    # Placeholder for the key the real schema handler would provide.
    FAKE_GET_CLIENT_ID_KEY_FROM_HANDLER = "clientID"
    # Emit the enum's string value: socketio event names are strings and the
    # payload must be JSON-serializable, which a raw Enum is not.
    event = MessageType.REQUEST_CLIENT_ID.value
    if initial_connection:
        logger.info("Connected to server, requesting client ID")
        await sio.emit(event)
    else:
        # Fix: the original logged client_id before assigning it (NameError).
        client_id = await state.get_client_id()
        logger.info(f"Reconnecting to client ID: {client_id}")
        await sio.emit(event, {FAKE_GET_CLIENT_ID_KEY_FROM_HANDLER: client_id})
async def connect_socketio(state):
    """Open the socket.io connection and feed incoming events into state.

    Callbacks only enqueue messages; actual processing happens in
    process_messages so slow handlers never block the socket.
    """
    sio = socketio.AsyncClient(
        reconnection=True,
        reconnection_attempts=5,
        reconnection_delay=1,
        reconnection_delay_max=5,
    )

    async def on_connect():
        # Consistency fix: use the module logger instead of print().
        logger.info("Connected to server")
        # Read state.client_id directly — the async getter would block
        # forever on the very first connection. Do NOT use it here.
        await on_reconnect(sio, state, initial_connection=state.client_id is None)

    async def on_disconnect():
        logger.info("Disconnected from server")

    async def on_message(event_name, data):
        await state.message_queue.put((event_name, data))

    sio.on("connect", on_connect)
    sio.on("disconnect", on_disconnect)
    sio.on("*", on_message)  # catch-all: every other event goes to the queue
    await sio.connect(WEBSOCKET_URL)
    await sio.wait()
async def send_request(client_id, entity_ids, correlation_ids):
    """POST the structure lookup to the find_structure endpoint.

    Returns the decoded JSON response body.
    """
    payload = {
        "clientID": client_id,
        "entityIDs": entity_ids,
        "correlationIDs": correlation_ids,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(FIND_STRUCTURE_ENDPOINT, json=payload) as response:
            return await response.json()
async def process_messages(state):
    """Drain state's message queue until TIMEOUT seconds have elapsed.

    Fix: the original awaited queue.get() unconditionally, so with an idle
    queue it blocked forever and the timeout was never enforced. wait_for
    bounds each get() by the remaining time budget.
    """
    while True:
        remaining = TIMEOUT - (time.time() - state.get_start_time())
        if remaining <= 0:
            break
        try:
            event_name, message = await asyncio.wait_for(
                state.message_queue.get(), timeout=remaining
            )
        except asyncio.TimeoutError:
            break  # time budget exhausted while waiting
        try:
            await state.process_message(event_name, message)
        except Exception:
            # Consistency fix: log (with traceback) via the module logger
            # instead of print(); keep draining subsequent messages.
            logger.exception("Error processing message")
        finally:
            state.message_queue.task_done()
async def client_loop(entity_ids, correlation_ids):
    """Run the socket connection, the HTTP request, and message processing.

    Fix: the original passed the un-awaited state.get_client_id() coroutine
    straight into send_request, which would then try to JSON-serialize a
    coroutine object. The request is wrapped so the client ID is awaited
    first, while the other tasks run concurrently.
    """
    state = ClientState(entity_ids, correlation_ids)

    async def _request_when_ready():
        # Block until the server has assigned a client ID, then fire the
        # find_structure request.
        client_id = await state.get_client_id()
        return await send_request(
            client_id, state.get_entity_ids(), state.get_correlation_ids()
        )

    await asyncio.gather(
        connect_socketio(state),
        _request_when_ready(),
        process_messages(state),
    )
@click.command()  # TODO: Incorrect, should be parsing a file, leaving as is for linting
@click.option("--entity_ids", type=str, help="Entity IDs")
@click.option("--correlation_ids", type=str, help="Correlation IDs")
@click.option("--debug", is_flag=True, default=False, help="Enable debug logging")
def main(entity_ids, correlation_ids, debug):
    """CLI entry point: configure logging, then run the async client loop."""
    # --debug wins; otherwise honor PYTHON_LOG_LEVEL, defaulting to INFO.
    level_name = "DEBUG" if debug else os.getenv("PYTHON_LOG_LEVEL", "INFO")
    logging.basicConfig(
        level=getattr(logging, level_name.upper(), logging.INFO),
        format="%(levelname)s %(name)s: %(message)s",
    )
    asyncio.run(client_loop(entity_ids, correlation_ids))


if __name__ == "__main__":
    main()