import asyncio
from enum import Enum
import logging
import os
import time
import aiohttp
import click
from google.cloud import bigquery, storage
import socketio
logger = logging.getLogger(__name__)
# Constants
WEBSOCKET_URL = "wss://your-cloud-run-instance-url"
TIMEOUT = 60 # Timeout in seconds
BQ_TABLE_ID = "structure-standard-table-id"
ENTITY_ID_COLUMN = "bcif_id"
BCIF_ID_COLUMN = "bcif_id"
AVRO_FILE_PATH_COLUMN = "avro_file_path"
FIND_STRUCTURE_ENDPOINT = "https://your-cloud-run-instance-url/find_structure"
MESSAGE_ENTITY_ID_KEY = "entityIDs"
MESSAGE_CORRELATION_ID_KEY = "correlationIDs"
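# Rough lifecycle, inferred from the code below: connect over socket.io and request a
# client ID, POST the entity IDs to FIND_STRUCTURE_ENDPOINT, then wait for
# structure_ready messages that point at Avro files to download from GCS.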
def split_uri_to_bucket_and_path(uri):
    """Split a gs:// object URI into its bucket name and object path."""
    split_uri = uri.split("/")
    bucket = split_uri[2]
    path = "/".join(split_uri[3:])
    return bucket, path
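# For example, assuming gs:// URIs as stored in the avro_file_path column:
#   split_uri_to_bucket_and_path("gs://my-bucket/structures/foo.avro")
#   returns ("my-bucket", "structures/foo.avro")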
async def download_blob(bucket, file_path, target_dir="."):
    """Download a single blob from GCS into target_dir."""
    gcs_client = storage.Client()
    bucket_client = gcs_client.bucket(bucket)
    blob = bucket_client.blob(file_path)
    target_path = os.path.join(target_dir, file_path)
    os.makedirs(os.path.dirname(target_path) or ".", exist_ok=True)  # object paths can contain slashes
    # The GCS client is synchronous; run the blocking download in a worker thread.
    await asyncio.to_thread(blob.download_to_filename, target_path)
async def join_get_bq_records_from_list(id_list, join_key_temp_table, join_key_bq, bq_table_id, columns_to_select):
    """Join the supplied IDs against bq_table_id and return the selected columns."""
    # Ship the IDs as an array query parameter instead of interpolating them into
    # the SQL; joining the unnested array replaces the original temp-table step.
    unnest_alias = f"requested_{join_key_temp_table}"
    query = f"""
    SELECT {', '.join(f'`{bq_table_id}`.{col}' for col in columns_to_select)}
    FROM `{bq_table_id}`
    JOIN UNNEST(@id_list) AS {unnest_alias}
    ON `{bq_table_id}`.{join_key_bq} = {unnest_alias}
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ArrayQueryParameter("id_list", "STRING", id_list)]
    )
    # Spin up a BQ client and execute the query
    bq_client = bigquery.Client()
    query_job = bq_client.query(query, job_config=job_config)
    # result() blocks, so run it in a worker thread and return the rows
    return list(await asyncio.to_thread(query_job.result))
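# With the placeholder constants above, the rendered query looks roughly like:
#   SELECT `structure-standard-table-id`.bcif_id, `structure-standard-table-id`.avro_file_path
#   FROM `structure-standard-table-id`
#   JOIN UNNEST(@id_list) AS requested_bcif_id
#   ON `structure-standard-table-id`.bcif_id = requested_bcif_id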
# Notification type enum. This should come from a centralized schema registry;
# hardcoded here as a placeholder, replace during integration.
class NotificationType(Enum):
    CLIENT_ID_ASSIGNMENT = "client_id_assignment"
class MessageType(Enum):  # This should also come from a centralized schema registry; hardcoded placeholder
    NOTIFICATION = "notification"
    STATUS = "status"
    STRUCTURE_READY = "structure_ready"
    REQUEST_CLIENT_ID = "request_client_id"
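# A structure_ready payload is assumed to look roughly like
#   {"entityIDs": ["abc123"], "correlationIDs": ["corr-001"]}
# keyed by the MESSAGE_ENTITY_ID_KEY / MESSAGE_CORRELATION_ID_KEY constants above.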
class FakeMessageSchemaHandler:
    # The real handler is a factory that picks the right handler for the message
    # type and supplies the schema keys; everything here is a hardcoded placeholder.
    def __init__(self, event_name, message):
        self._event_name = event_name
        self._message = message
        self.notification_type = NotificationType.CLIENT_ID_ASSIGNMENT
        self.status_type = None
        self.client_id = "clientID"  # assumed message key, matching on_reconnect below
        self.details = "details"  # assumed message key for status payloads
        self.entity_ids = MESSAGE_ENTITY_ID_KEY
        self.correlation_ids = MESSAGE_CORRELATION_ID_KEY
class ClientState:
    def __init__(self, entity_ids, correlation_ids, target_dir="."):
        self.client_id = None
        self.task_queue = asyncio.Queue()
        self.message_queue = asyncio.Queue()  # Queue for incoming messages
        self._start_time = time.time()
        self._target_dir = target_dir  # download dir
        self._entity_ids = entity_ids
        self._correlation_ids = correlation_ids
        self._message_handlers = {
            MessageType.NOTIFICATION: self.process_notification,
            MessageType.STATUS: self.process_status,
            MessageType.STRUCTURE_READY: self.process_structure_ready,
        }

    def _get_handler(self, message_type, message):
        return FakeMessageSchemaHandler(message_type, message)
    async def set_client_id(self, client_id):
        self.client_id = client_id
    async def get_client_id(self):
        while self.client_id is None:
            await asyncio.sleep(0.1)  # Wait for client_id to be set
        return self.client_id

    # Message dispatching logic
    async def process_message(self, event_name, message):
        """Dispatch an incoming message to the handler registered for its type."""
        # async because the structure_ready handler awaits BQ/GCS work
        try:
            message_type = MessageType(event_name)
        except ValueError:
            logger.warning(f"No handler found for event name: {event_name}")
            return
        handler = self._message_handlers.get(message_type)
        if handler is None:
            logger.warning(f"No handler found for event name: {event_name}")
            return
        await handler(message)
    async def process_notification(self, message):
        logger.info(f"Processing notification: {message}")
        handler = self._get_handler(MessageType.NOTIFICATION, message)
        # The only notification we care about is client ID assignment
        if handler.notification_type == NotificationType.CLIENT_ID_ASSIGNMENT:
            await self.set_client_id(message[handler.client_id])
        else:
            logger.warning(f"Unknown notification type: {handler.notification_type}")
            logger.info(f"Message: {message}")
    async def process_status(self, message):
        logger.info(f"Processing status: {message}")
        handler = self._get_handler(MessageType.STATUS, message)
        # Just log the statuses; nothing else is implemented yet
        logger.info(f"Status: {handler.status_type}")
        details = message.get(handler.details)
        if details:
            logger.info(f"Details: {details}")
    async def process_structure_ready(self, message):
        # Async because we query BQ, check GCS, and then download files
        logger.info(f"Processing structure ready: {message}")
        handler = self._get_handler(MessageType.STRUCTURE_READY, message)
        # Get the entity IDs from the message
        entity_ids = message.get(handler.entity_ids)
        if not entity_ids:
            logger.warning("No entity IDs found in structure ready message")
            return
        logger.info(f"Entity IDs: {entity_ids}")
        # Join the entity IDs against the structure standard table and select
        # the bcif_id and avro_file_path columns
        columns_to_select = [BCIF_ID_COLUMN, AVRO_FILE_PATH_COLUMN]
        results = await join_get_bq_records_from_list(
            entity_ids, ENTITY_ID_COLUMN, ENTITY_ID_COLUMN, BQ_TABLE_ID, columns_to_select
        )
        logger.info(f"Results: {results}")
        # Group file paths by bucket to drop redundant refs to avro files/buckets
        avro_data_file_dict = {}
        for result in results:
            bucket, file_path = split_uri_to_bucket_and_path(result[AVRO_FILE_PATH_COLUMN])
            avro_data_file_dict.setdefault(bucket, []).append(file_path)
        logger.info(f"Unique buckets: {list(avro_data_file_dict)}")
        for bucket, file_paths in avro_data_file_dict.items():
            await asyncio.gather(
                *[download_blob(bucket, file_path, self._target_dir) for file_path in file_paths]
            )
        # TODO: decide how the entity ID (and corresponding bcif ID) should map to the avro
        # file, for cases where the avro file holds many records and we only want some of them
    def get_entity_ids(self):
        return self._entity_ids
    def get_correlation_ids(self):
        return self._correlation_ids
    def get_start_time(self):
        return self._start_time
async def on_reconnect(sio, state, initial_connection=False):
    FAKE_GET_CLIENT_ID_KEY_FROM_HANDLER = "clientID"
    if initial_connection:
        logger.info("Connected to server, requesting client ID")
        await sio.emit(MessageType.REQUEST_CLIENT_ID.value)
    else:
        client_id = await state.get_client_id()
        logger.info(f"Reconnecting with client ID: {client_id}")
        await sio.emit(
            MessageType.REQUEST_CLIENT_ID.value,
            {FAKE_GET_CLIENT_ID_KEY_FROM_HANDLER: client_id},
        )
async def connect_socketio(state):
    sio = socketio.AsyncClient(
        reconnection=True,
        reconnection_attempts=5,
        reconnection_delay=1,
        reconnection_delay_max=5,
    )

    async def on_connect():
        logger.info("Connected to server")
        await on_reconnect(
            sio, state, initial_connection=state.client_id is None
        )  # do NOT use the async getter here

    async def on_disconnect():
        logger.info("Disconnected from server")

    async def on_message(event_name, data):
        await state.message_queue.put((event_name, data))

    sio.on("connect", on_connect)
    sio.on("disconnect", on_disconnect)
    sio.on("*", on_message)  # catch-all for events without dedicated handlers
    await sio.connect(WEBSOCKET_URL)
    await sio.wait()
async def send_request(client_id, entity_ids, correlation_ids):
    """Send an HTTP POST to the find_structure endpoint."""
    data = {
        "clientID": client_id,
        "entityIDs": entity_ids,
        "correlationIDs": correlation_ids,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(FIND_STRUCTURE_ENDPOINT, json=data) as response:
            return await response.json()
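# The endpoint is assumed to reply with a JSON acknowledgement; client_loop below
# discards it, since the actual results arrive later over the socket.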
async def process_messages(state):
    # Drain the queue until the overall timeout elapses.
    while True:
        remaining = TIMEOUT - (time.time() - state.get_start_time())
        if remaining <= 0:
            break
        try:
            event_name, message = await asyncio.wait_for(
                state.message_queue.get(), timeout=remaining
            )
        except asyncio.TimeoutError:
            break
        try:
            await state.process_message(event_name, message)
        except Exception as e:
            logger.error(f"Error processing message: {e}")
        finally:
            state.message_queue.task_done()
async def client_loop(entity_ids, correlation_ids):
    state = ClientState(entity_ids, correlation_ids)

    async def request_structures():
        # get_client_id is a coroutine, so wait for the ID before posting
        client_id = await state.get_client_id()
        await send_request(client_id, state.get_entity_ids(), state.get_correlation_ids())

    await asyncio.gather(
        connect_socketio(state),
        request_structures(),
        process_messages(state),
    )
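# Hypothetical invocation (the option format is a stopgap; see the TODO below):
#   python structure_client.py --entity_ids abc123 --correlation_ids corr-001 --debug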
@click.command()
# TODO: Incorrect, should be parsing a file, leaving as is for linting
@click.option("--entity_ids", type=str, help="Entity IDs")
@click.option("--correlation_ids", type=str, help="Correlation IDs")
@click.option("--debug", is_flag=True,default=False, help="Enable debug logging")
def main(entity_ids, correlation_ids, debug):
logging.basicConfig(
level=getattr(
logging,
("DEBUG" if debug else os.getenv("PYTHON_LOG_LEVEL", "INFO")).upper(),
logging.INFO),
format="%(levelname)s %(name)s: %(message)s",)
asyncio.run(client_loop(entity_ids, correlation_ids))
if __name__ == "__main__":
    main()