Production ML with Snowflake and dbt


Running Production ML with Snowflake and dbt

Snowflake runs Python now (through Snowpark), which is cool. It lets you train models, run predictions, and whatnot. But the world has come a long way since training a model was impressive. Nowadays, training is table stakes; the real work is serving, tracking, monitoring, and everything else related to deploying and maintaining models in production. Model registries, MLflow, Weights & Biases, model serving via REST APIs… all in a day’s work for the people at Databricks or similar places. But if you are on Snowflake and want even the bare minimum of a production-grade ML system, what do you do?

A sketch

I have cobbled together a sketch of something that would satisfy at least the basic requirements while using nothing more than out-of-the-box Snowflake and the dbt that you already have and use for all your normal data stuff.

Conceptually, we use python in dbt to train an ML model, and save the model to a stage (aka blob storage). We version the models, and the table that results from the dbt model contains the ID of the model that was trained, and the test-scores (accuracy, RMSE, whatever you fancy) for the latest run. Add a snapshot to keep track of the history.

This saved ML-model can then be used by other dbt-models in order to run predictions. In other words, we split the training and the prediction steps into two different models (or 1 + N models, if you need to). Because we have a snapshot of the training dbt-model, we have a log of all the training runs, the accuracies, timestamps and whatever else. And we have a link to the ML-model (artifact), so we can programmatically pick up the newest model, the best model, or whatever we want.
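As a rough illustration of "newest" versus "best", here is a minimal sketch in plain Python. The rows are made up, and the column names (UUID, MSE, R2, TIMESTAMP) are assumed to match the training table built later in this post. In practice the rows would come from the snapshot table rather than a hard-coded list.

```python
from datetime import datetime

# Hypothetical training-run log; values are invented for illustration.
runs = [
    {"UUID": "a1", "MSE": 4.0, "R2": 0.71, "TIMESTAMP": datetime(2024, 1, 1)},
    {"UUID": "b2", "MSE": 2.5, "R2": 0.83, "TIMESTAMP": datetime(2024, 1, 8)},
    {"UUID": "c3", "MSE": 3.1, "R2": 0.78, "TIMESTAMP": datetime(2024, 1, 15)},
]


def newest_run(runs):
    """Return the run with the latest timestamp."""
    return max(runs, key=lambda r: r["TIMESTAMP"])


def best_run(runs):
    """Return the run with the highest R2 score."""
    return max(runs, key=lambda r: r["R2"])


print(newest_run(runs)["UUID"])  # c3
print(best_run(runs)["UUID"])    # b2
```

Which of the two you want is a policy decision: "newest" follows the data, while "best" protects you from a bad retraining run.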

No fancy tools, no product named “model registry”, but we do have working ML in a system that makes sense and is governed by the same things that govern the rest of your data.

As a bonus: In dbt-snowflake-python, the dbt object that is available includes a reference to the current model (dbt.this.database, dbt.this.schema and dbt.this.identifier). By carefully constructing the location of the ML-model from these, we can keep the separation between development and production along the same lines as we separate dbt models. The exception is the snapshot, which in dbt doesn’t get prefixed and is always just production. That should be manageable though, by including the model database/schema in the table so that we can filter on it if need be.
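A minimal sketch of that path construction, as plain Python. The helper name model_stage_path and the database/schema values are hypothetical; inside a dbt Python model the three parts would come from dbt.this.

```python
def model_stage_path(stage_root, database, schema, identifier):
    """Build the stage folder that holds one dbt model's ML artifacts."""
    parts = [stage_root.rstrip("/"), database, schema, identifier]
    return "/".join(parts) + "/"


# Hypothetical dev and prod targets for the same dbt model.
dev = model_stage_path(
    "@dbthouse.develop.ml_models", "DBTHOUSE", "DEV_ALICE", "walmart_weekly_sales_train"
)
prod = model_stage_path(
    "@dbthouse.develop.ml_models", "DBTHOUSE", "PROD", "walmart_weekly_sales_train"
)

print(dev)   # @dbthouse.develop.ml_models/DBTHOUSE/DEV_ALICE/walmart_weekly_sales_train/
print(prod)  # @dbthouse.develop.ml_models/DBTHOUSE/PROD/walmart_weekly_sales_train/
```

Because the schema is part of the path, artifacts written from a development target never land in the same folder as production ones.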

What does this look like?

One of my dbt projects has a small PoC implementation of this. It uses a small Walmart sales dataset from Kaggle as an example, and the main parts are here:

The sketch consists of four models:

  • walmart_sales is the table containing the real data. The origin, if you will.
  • walmart_weekly_sales_train is a Python model that trains an ML model and saves it.
  • walmart_weekly_sales_predict is a Python model that uses the ML model for prediction purposes.
  • training_runs_snapshot is a regular snapshot of walmart_weekly_sales_train - it is so simple, I won’t even talk about it. But you can find it here:

Training

The dbt model that trains the ML model:

from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import datetime
import uuid
import joblib

def model(dbt, session):
    dbt.config(
        materialized = "table",
        packages = ["scikit-learn", "joblib", "pandas"],
        snowflake_warehouse = "PARK"
    )
    stillingshistorikk_person = dbt.ref("walmart_sales").to_pandas().dropna()

    X = stillingshistorikk_person.drop(columns=['WEEKLY_SALES', "SALES_WEEK_START"])
    y = stillingshistorikk_person['WEEKLY_SALES']
    # categorical_columns = X.select_dtypes(include=['object']).columns
    # X_ohe = pd.get_dummies(X, columns=categorical_columns, drop_first=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = LinearRegression().fit(X_train, y_train)

    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    now = datetime.datetime.utcnow()
    model_uuid = str(uuid.uuid4())
    database = dbt.this.database
    schema = dbt.this.schema
    identifier = dbt.this.identifier

    accuracy_df = pd.DataFrame({
        "UUID": [model_uuid],
        "MSE": [mse],
        "R2": [r2],
        "TIMESTAMP": [now],
        "DATABASE": [database],
        "SCHEMA": [schema],
        "IDENTIFIER": [identifier]
    })

    stage = f"@dbthouse.develop.ml_models/{database}/{schema}/{identifier}/"
    model_loc = f"/tmp/{model_uuid}.pkl"
    joblib.dump(model, model_loc)
    session.file.put(model_loc, stage)

    return accuracy_df

Most of this code is boilerplate scikit-learn linear regression. Feel free to swap in a fancier model, but this is sufficient as an illustration. We read the data with the special dbt function dbt.ref("walmart_sales"), which lets dbt keep track of lineage: you can run things like dbt run -s +my_model, and dbt still understands which models are upstream of our Python models.

We use a stage named @dbthouse.develop.ml_models for storing our model. joblib.dump persists the model to a local file (under /tmp), and the Snowpark function session.file.put uploads that file to permanent storage in our stage. By default, put gzip-compresses the file on upload, so the artifact lands in the stage as {model_uuid}.pkl.gz. For production I’d recommend using an external stage, but this works with any stage.
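The dump-then-upload round trip can be tried locally without Snowflake. The sketch below is a stand-in, not the real thing: it uses pickle and gzip from the standard library in place of joblib and the stage, and the helper names and the fake model are made up. It only shows why a file written as .pkl is later read back as .pkl.gz when the upload step compresses it.

```python
import gzip
import os
import pickle
import shutil
import tempfile


def save_and_compress(obj, directory, name):
    """Pickle obj to <name>.pkl, then gzip it to <name>.pkl.gz (mimics a compressing upload)."""
    raw_path = os.path.join(directory, f"{name}.pkl")
    with open(raw_path, "wb") as f:
        pickle.dump(obj, f)
    gz_path = raw_path + ".gz"
    with open(raw_path, "rb") as src, gzip.open(gz_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return gz_path


def load_compressed(gz_path):
    """Read a gzip-compressed pickle back into a Python object."""
    with gzip.open(gz_path, "rb") as f:
        return pickle.load(f)


workdir = tempfile.mkdtemp()
fake_model = {"coef": [1.5, -0.2], "intercept": 3.0}  # stand-in for a fitted model
path = save_and_compress(fake_model, workdir, "example-model")
restored = load_compressed(path)
print(os.path.basename(path))  # example-model.pkl.gz
```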

Note that we are careful with designing the location of the model. We define some convenience variables:

model_uuid = str(uuid.uuid4())
database = dbt.this.database
schema = dbt.this.schema
identifier = dbt.this.identifier

We name the file {model_uuid}.pkl, and save it at @dbthouse.develop.ml_models/{database}/{schema}/{identifier}/. This way the identity is unique, and we maintain some logic to the stage location.

Finally, we create a tiny dataframe populated with some key metrics and information from our training run. It does not contain predictions on real data; we do that in a separate model to keep things clean.
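For readers who want to see what the two logged metrics measure, here is a minimal hand-rolled sketch of MSE and R². The function names and the sample numbers are made up for illustration; the dbt model itself uses scikit-learn's mean_squared_error and r2_score.

```python
def mse_by_hand(y_true, y_pred):
    """Mean of the squared differences between actual and predicted values."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_by_hand(y_true, y_pred):
    """1 minus (residual sum of squares / total sum of squares)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.0, 5.0, 8.0, 9.0]

mse = mse_by_hand(y_true, y_pred)  # (1 + 0 + 1 + 0) / 4 = 0.5
r2 = r2_by_hand(y_true, y_pred)    # 1 - 2 / 20 = 0.9
print(mse, r2)
```

Lower MSE is better and higher R² is better, which matters when you later sort the run log to pick a "best" model.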

Also note that I’m using a specialized warehouse (PARK) for these Snowpark jobs. A regular x-small probably won’t do.

Inference

When we want to run inference, we need to perform three steps:

  1. Find the name of the model we want to use
  2. Load it into python
  3. Run the actual inference

The code that does this:

import sys
from joblib import load
import pandas as pd

def model(dbt, session):
    dbt.config(
        materialized = "table",
        packages = ["scikit-learn", "joblib", "pandas"],
        snowflake_warehouse = "PARK"
    )

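    # Pick the most recent training run; column names match the training model's output.
    latest = dbt.ref("walmart_weekly_sales_train").to_pandas().sort_values(by="TIMESTAMP", ascending=False).iloc[0]
    model_id = latest["UUID"]

    # Copy the artifact from the stage to local disk before loading it.
    # This call is a reconstruction (it is missing from the listing); the path mirrors the training model.
    stage = f"@dbthouse.develop.ml_models/{latest['DATABASE']}/{latest['SCHEMA']}/{latest['IDENTIFIER']}/"
    session.file.get(f"{stage}{model_id}.pkl.gz", "/tmp/mdl")

    model = load(f"/tmp/mdl/{model_id}.pkl.gz")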

    walmart_sales = dbt.ref("walmart_sales").to_pandas().dropna()
    df = walmart_sales.copy()

    y = df['WEEKLY_SALES']
    X = df.drop(columns=['WEEKLY_SALES', "SALES_WEEK_START"])

    y_pred = model.predict(X)
    walmart_sales['PREDICTED_SALES'] = y_pred

    return walmart_sales

In our simplified case, we read the model ID not from the snapshot but from the original table produced by the training run. In practice, we might want to read from the snapshot table and filter to avoid accidentally using a model written during development.
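A minimal sketch of that filter, in plain Python. The rows, the schema names and the helper latest_model_id are hypothetical; the column names are assumed to match the training table. With the real snapshot you would apply the same filter to the snapshot table before picking a model.

```python
from datetime import datetime


def latest_model_id(rows, database, schema):
    """Return the UUID of the newest run trained in the given database and schema."""
    candidates = [
        r for r in rows
        if r["DATABASE"] == database and r["SCHEMA"] == schema
    ]
    if not candidates:
        raise ValueError(f"no training runs found for {database}.{schema}")
    return max(candidates, key=lambda r: r["TIMESTAMP"])["UUID"]


# Hypothetical snapshot contents: the newest run overall comes from a dev schema.
rows = [
    {"UUID": "prod-1", "DATABASE": "DBTHOUSE", "SCHEMA": "PROD", "TIMESTAMP": datetime(2024, 1, 1)},
    {"UUID": "prod-2", "DATABASE": "DBTHOUSE", "SCHEMA": "PROD", "TIMESTAMP": datetime(2024, 1, 8)},
    {"UUID": "dev-1", "DATABASE": "DBTHOUSE", "SCHEMA": "DEV_ALICE", "TIMESTAMP": datetime(2024, 1, 9)},
]

print(latest_model_id(rows, "DBTHOUSE", "PROD"))  # prod-2
```

Without the filter, the development run dev-1 would win simply because it is the most recent.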

After that, we load the model by first copying it back to the local file system with session.file.get and then using joblib’s load function to get it into Python. From there, we do what everyone does: call predict, add a column with the predicted values to the dataset, and return it.

Best of all, all of this runs in dbt: you can choose when to run the training job, and you keep track of the training runs and the models. All in fairly few lines of code.

Because we use dbt refs, we keep the lineage, and we should be able to run dbt run -s +walmart_weekly_sales_predict. Since I’m not actually using the snapshot in my example, this would run the training as well. But if I substitute in the snapshot, the snapshot will not be run and so the upstream training job will not run.