Audit Logs ETL - Databricks

If re-running this notebook, first remove all metastore entries and files.

%sql
DROP DATABASE IF EXISTS audit_logs CASCADE

%fs rm -r /tmp/audit_logs_example

Start by importing the required libraries, which should already be included in recent Databricks Runtime versions.

from pyspark.sql.functions import udf, col, from_unixtime, from_utc_timestamp, from_json
from pyspark.sql.types import StringType, StructField, StructType
import json, time, requests

You have two options for using this notebook (the widgets themselves can be created as sketched below):

  • process your own Databricks audit logs by selecting s3bucket in the Data Source widget and entering the prefix where Databricks delivers them in the Audit Logs Source S3 bucket widget
  • use generated data based on the schema of real Databricks audit logs by selecting fakeData in the Data Source widget
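
If the widgets don't already exist in this notebook, a minimal sketch like the following creates them (the names match those referenced in the next cell; the empty defaults are placeholders for you to fill in):

dbutils.widgets.dropdown("Data Source", "fakeData", ["fakeData", "s3bucket"]) # choose between generated data and your own audit logs
dbutils.widgets.text("Audit Logs Source S3 bucket", "") # placeholder: prefix where Databricks delivers your audit logs
dbutils.widgets.text("Delta Lake Sink S3 bucket", "") # placeholder: bucket where the Delta Lake tables will be written
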
if dbutils.widgets.getArgument("Data Source") == 'fakeData':
  rawBytes = requests.get("https://raw.githubusercontent.com/craigng/audit_logs_test/master/test.txt") # fetch the contents of test.txt over HTTPS using the requests library
  _list = rawBytes.text.split("|") # split the pipe-delimited file into a list of JSON strings
  listAsString = json.dumps(_list) # serialize the list of JSON strings into a single JSON array string
  jsonList = json.loads(listAsString) # deserialize that string back into a Python list
  _rdd = sc.parallelize(jsonList) # distribute the list of JSON strings as an RDD
  spark.read.json(_rdd).write.mode("overwrite").partitionBy("date").json("/tmp/audit_logs_example/raw_data") # write the files to DBFS to serve as the source bucket
  sourceBucket = "/tmp/audit_logs_example/raw_data"
  sinkBucket = "/tmp/audit_logs_example" # save to DBFS
else:
  sourceBucket = dbutils.widgets.getArgument("Audit Logs Source S3 bucket")
  sinkBucket = dbutils.widgets.getArgument("Delta Lake Sink S3 bucket")

Databricks delivers audit logs daily to a customer-specified S3 bucket as JSON files. Rather than writing logic to determine the state of our Delta Lake tables, we're going to utilize Structured Streaming's write-ahead logs and checkpoints to maintain the state of our tables. In this case, we've designed our ETL to run once per day, so we're using a file source with a run-once trigger (trigger(once=True)) to simulate a batch workload with a streaming framework. Since Structured Streaming requires that we explicitly define the schema, we'll read the raw JSON files once to infer it.

streamSchema = spark.read.json(sourceBucket).schema

We instantiate our StreamReader using the schema we inferred and the path to the raw audit logs.

streamDF = (
  spark
  .readStream
  .format("json")
  .schema(streamSchema)
  .load(sourceBucket)
)

We then instantiate our StreamWriter and write out the raw audit logs into a bronze Delta Lake table that's partitioned by date.

(streamDF
 .writeStream
 .format("delta")
 .partitionBy("date")
 .outputMode("append")
 .option("checkpointLocation", "{}/checkpoints/bronze".format(sinkBucket))
 .option("path", "{}/streaming/bronze".format(sinkBucket))
 .option("mergeSchema", True)
 .trigger(once=True)
 .start()
)

Since the stream runs in a separate thread from the other processes, we want to ensure that the query finishes before we proceed. We can use information available via the SparkSession (specifically spark.streams.active) to check before proceeding.

while spark.streams.active != []:
  print("Waiting for streaming query to finish.")
  time.sleep(5)
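
Equivalently (a sketch, not how the rest of this notebook is wired up), you could block on each active StreamingQuery directly, since awaitTermination returns once a trigger-once query finishes:

for q in spark.streams.active: # each element is a pyspark.sql.streaming.StreamingQuery
  q.awaitTermination() # blocks until this trigger-once query completes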

Now that we've created the table on an external S3 bucket, we'll need to register the table in the internal Databricks Hive metastore to make access to the data easier for end users. We'll create both the database/schema audit_logs and the bronze table.

%sql
CREATE DATABASE IF NOT EXISTS audit_logs

spark.sql("""
CREATE TABLE IF NOT EXISTS audit_logs.bronze
USING DELTA
LOCATION '{}/streaming/bronze'
""".format(sinkBucket))

%sql
OPTIMIZE audit_logs.bronze

If you update your Delta Lake tables in batch or pseudo-batch fashion, it's best practice to run OPTIMIZE immediately following an update.
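
If most downstream queries filter on a particular column, OPTIMIZE can also co-locate related records via ZORDER BY; the choice of serviceName below is an assumption about your query patterns rather than a requirement:

%sql
OPTIMIZE audit_logs.bronze ZORDER BY (serviceName)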

Since we ship audit logs for all services in a single JSON format, we've defined a struct called requestParams which contains a union of the keys for all services. Eventually, we're going to create individual tables for each service, so we want to strip down the requestParams field for each table so that it contains only the keys relevant to that service. To accomplish this, we define a UDF that strips away all keys in requestParams that have null values.

def stripNulls(raw):
  # convert the Row to a dict once, drop any keys with null values, and serialize the rest to a JSON string
  rawDict = raw.asDict()
  return json.dumps({k: v for k, v in rawDict.items() if v is not None})
strip_udf = udf(stripNulls, StringType())
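
As a quick illustration (the Row below is a made-up requestParams value, not real audit-log data), keys with null values are dropped and the remainder is serialized to a JSON string:

from pyspark.sql import Row

sampleParams = Row(clusterName="my-cluster", numWorkers=None, sparkVersion="7.3.x-scala2.12") # hypothetical requestParams row
stripNulls(sampleParams) # returns '{"clusterName": "my-cluster", "sparkVersion": "7.3.x-scala2.12"}'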

We instantiate a StreamReader from our bronze Delta Lake table to stream updates to our silver Delta Lake table.

bronzeDF = (
  spark
  .readStream
  .format("delta")
  .load("{}/streaming/bronze".format(sinkBucket))
)

We apply the following transformations when going from the bronze Delta Lake table to the silver Delta Lake table:

  • strip the null keys from requestParams and store the output as a string
  • parse email from userIdentity
  • parse an actual timestamp from the timestamp field and store it in date_time
  • drop the raw requestParams and userIdentity

A quick preview of the bronze table before applying these transformations:

%sql
SELECT * FROM audit_logs.bronze

query = (
  bronzeDF
  .withColumn("flattened", strip_udf("requestParams"))
  .withColumn("email", col("userIdentity.email"))
  .withColumn("date_time", from_utc_timestamp(from_unixtime(col("timestamp")/1000), "UTC"))
  .drop("requestParams")
  .drop("userIdentity")
)

We then stream our changes from the bronze Delta Lake table to the silver Delta Lake table.

(query
 .writeStream
 .format("delta")
 .partitionBy("date")
 .outputMode("append")
 .option("checkpointLocation", "{}/checkpoints/silver".format(sinkBucket))
 .option("path", "{}/streaming/silver".format(sinkBucket))
 .option("mergeSchema", True)
 .trigger(once=True)
 .start()
)

Since the stream runs in a separate thread from the other processes, we want to ensure that the query finishes before we proceed. We can use information available via the SparkSession (specifically spark.streams.active) to check before proceeding.

while spark.streams.active != []:
  print("Waiting for streaming query to finish.")
  time.sleep(5)

Now that we've created the table on an external S3 bucket, we'll need to register the table in the internal Databricks Hive metastore to make access to the data easier for end users.

spark.sql("""
CREATE TABLE IF NOT EXISTS audit_logs.silver
USING DELTA
LOCATION '{}/streaming/silver'
""".format(sinkBucket))

Although Structured Streaming provides exactly-once processing guarantees, we can still write an assertion that checks that the number of records in the bronze Delta Lake table equals the number of records in the silver Delta Lake table.

assert(spark.table("audit_logs.bronze").count() == spark.table("audit_logs.silver").count())

As mentioned before, we now run OPTIMIZE on the silver Delta Lake table.

%sql
OPTIMIZE audit_logs.silver

In the final step of our ETL process, we first define a UDF to parse the keys from the stripped down version of the original requestParams field.

def justKeys(string):
  return [i for i in json.loads(string).keys()]
# note: the UDF is declared with a StringType return even though justKeys returns a list;
# the value therefore comes back as a single string such as "[key1, key2]", which is
# exactly what flattenTable parses apart below
just_keys_udf = udf(justKeys, StringType())
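
For example, on an illustrative stripped-down requestParams string:

justKeys('{"clusterName": "my-cluster", "numWorkers": "8"}') # returns ['clusterName', 'numWorkers']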

Define a function which accomplishes the following:

  • gathers the keys for each record for a given serviceName
  • creates a set of those keys (to remove duplicates)
  • creates a schema from those keys to apply to the given serviceName (if the serviceName does not have any keys in requestParams, we give it a schema with a single key named placeholder)
  • writes out to an individual gold Delta Lake table for each serviceName in the silver Delta Lake table
def flattenTable(serviceName, bucketName):
  flattenedStream = spark.readStream.format("delta").load("{}/streaming/silver".format(bucketName))
  flattened = spark.table("audit_logs.silver")
  
  schema = StructType()
  
  keys = (
    flattened
    .filter(col("serviceName") == serviceName)
    .select(just_keys_udf(col("flattened")).alias("keys"))
    .distinct()
    .collect()
  )
  
  keysList = [i.asDict()['keys'][1:-1].split(", ") for i in keys]
  
  keysDistinct = {j for i in keysList for j in i if j != ""}
  
  if len(keysDistinct) == 0:
    schema.add(StructField('placeholder', StringType()))
  else:
    for i in keysDistinct:
      schema.add(StructField(i, StringType()))
    
  (flattenedStream
   .filter(col("serviceName") == serviceName)
   .withColumn("requestParams", from_json(col("flattened"), schema))
   .drop("flattened")
   .writeStream
   .partitionBy("date")
   .outputMode("append")
   .format("delta")
   .option("checkpointLocation", "{}/checkpoints/gold/{}".format(bucketName, serviceName))
   .option("path", "{}/streaming/gold/{}".format(bucketName, serviceName))
   .option("mergeSchema", True)
   .trigger(once=True)
   .start()
  )

This cell gets a list of the distinct values in serviceName and collects them as a Python list so we can iterate.

serviceNameList = [i['serviceName'] for i in spark.table("audit_logs.silver").select("serviceName").distinct().collect()]

We then run the flattenTable function for each serviceName.

for i in serviceNameList:
  flattenTable(i, sinkBucket)

Since the stream runs in a separate thread from the other processes, we want to ensure that the query finishes before we proceed. We can use information available via the SparkSession (specifically spark.streams.active) to check before proceeding.

while spark.streams.active != []:
  print("Waiting for streaming query to finish.")
  time.sleep(5)

Now that we've created the gold tables on an external S3 bucket, we'll need to register each of them in the internal Databricks Hive metastore to make access to the data easier for end users.

for i in serviceNameList:
  spark.sql("""
  CREATE TABLE IF NOT EXISTS audit_logs.{0}
  USING DELTA
  LOCATION '{1}/streaming/gold/{2}'
  """.format(i,sinkBucket,i))

Again, we run OPTIMIZE for each of the tables.

for i in serviceNameList:
  spark.sql("OPTIMIZE audit_logs.{}".format(i))

Repeat the assertion to ensure that the count of the silver Delta Lake table matches the sum of counts for each of the gold Delta Lake tables.

flattened_count = spark.table("audit_logs.silver").count()
total_count = 0
for i in serviceNameList:
  total_count += (spark.table("audit_logs.{}".format(i)).count())
assert(flattened_count == total_count)

We now have a gold Delta Lake table for each serviceName that Databricks tracks in its audit logs, which we can now use for monitoring and analysis.

%sql
SHOW TABLES IN audit_logs
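
For example, assuming an accounts table shows up in the listing above (which services appear depends on the activity captured in your audit logs), a query like the following surfaces login events; the email, actionName, and date_time columns carry over from the silver schema:

%sql
SELECT email, actionName, date_time
FROM audit_logs.accounts
WHERE actionName = 'login'
ORDER BY date_time DESC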