
Challenges productionizing embedding engines


As many applied ML practitioners know, productionizing ML tools can be deceptively difficult.

At Algorithmia we’re always striving to make our algorithms the best in class, and we’ve recently made a series of performance and UX changes to our Document Classifier algorithm, and put work towards generalizing it to other problem spaces outside of NLP. These changes were dramatic; we reduced our lookup time from O(n) to O(log(n)) and drastically improved the user experience by reducing unnecessary clutter, but it was far from easy.

In this blog post, we’ll get into the technical aspects of what problems we had, and how we tackled them.
Fair warning: there’s a lot to digest here, but it’s actually pretty simple to consume as a user. Check out the Document Classifier on Algorithmia, and let us know about interesting use cases you’re thinking about.

Still here? Great! Before we dive into the challenges we faced, let’s have a quick refresher on recommendation engines and data embeddings.

What’s this embedding vector thing?

An embedding vector is a data representation that’s been mapped to a lower dimensional (and usually, dense) vector while still preserving the uniqueness and structure of the original data itself. If this uniqueness requirement is not preserved, then the process is not reversible and information is lost – losing information is generally not something we want, at least not without a good reason.

Before the advent of neural networks, creating dense embedding representations of high-dimensional data like language or images was very difficult, if not impossible.

what is an embedding

Auto-encoding neural networks, and other unsupervised neural network architectures (such as skip-gram, NBOW or siamese models), preserve the structure and uniqueness of the data while transforming it into a dense representation inside their hidden layers.

word2vec embedding arithmetic

Source: Tensorflow

These representations have some very interesting properties; because they are vectors, they can have linear relationships with each other, which means you can do fun stuff like calculating the distances between entities, or even performing basic arithmetic operations on them while still preserving meaning.
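
To make that concrete, here is a minimal, self-contained sketch of those two properties, distance and arithmetic. The toy vectors and helper functions are purely illustrative and are not part of any Algorithmia API:

// A toy illustration of embedding arithmetic and similarity.
// The 3-dimensional vectors below are made up; real embeddings are
// typically hundreds of dimensions wide.
object EmbeddingArithmetic {
  type Vec = Array[Float]

  def add(a: Vec, b: Vec): Vec      = a.zip(b).map { case (x, y) => x + y }
  def subtract(a: Vec, b: Vec): Vec = a.zip(b).map { case (x, y) => x - y }

  // Cosine similarity: close to 1.0 means "pointing the same way".
  def cosine(a: Vec, b: Vec): Double = {
    val dot   = a.zip(b).map { case (x, y) => x * y }.sum
    val normA = math.sqrt(a.map(x => x * x).sum)
    val normB = math.sqrt(b.map(x => x * x).sum)
    dot / (normA * normB)
  }

  def main(args: Array[String]): Unit = {
    val king  = Array(0.9f, 0.8f, 0.1f)
    val man   = Array(0.2f, 0.9f, 0.0f)
    val woman = Array(0.2f, 0.1f, 0.8f)
    val queen = Array(0.9f, 0.0f, 0.9f)

    // The classic word2vec-style analogy: king - man + woman is close to queen.
    val analogy = add(subtract(king, man), woman)
    println(f"similarity(analogy, queen) = ${cosine(analogy, queen)}%.3f")
  }
}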

What’s a Recommendation Engine?

recommendations

A recommendation engine, or content recommender, describes a group of technologies used in industry to find the users and interests that best match yours, so that products and services more relevant to your interests can be recommended to you.

How can we construct such a complex system? There are many techniques you can use to do this, but let’s take a crack at it by looking at those embedding vectors again. If you’re able to convert your user data (like documents, or lists of documents) into embeddings, then you’re close. Once you collect a bunch of embedding vectors and their real-space equivalents together in a collection, you can find the nearest neighbours to a specific profile by using the k-nearest neighbours algorithm.

That’s not all! If you label your vectors (or even a subset of them) in your collection, you also get a zero-shot classifier for free from this unsupervised embedding approach. Neat, huh?
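
As a rough sketch of that idea in its simplest, brute-force form (before any of the indexing work described below), here is what a linear-scan nearest-neighbour lookup and majority-vote classifier might look like. The types here are assumptions for illustration, not our production code:

// A naive O(n) nearest-neighbour search over a labeled collection of embeddings.
// The rest of this post is about replacing this linear scan with something that
// scales; the logic of the final vote, however, stays the same.
object NaiveKnn {
  case class LabeledEmbedding(label: String, vector: Array[Float])

  // Squared Euclidean distance is enough for ranking neighbours.
  private def sqDist(a: Array[Float], b: Array[Float]): Float =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Labels of the k vectors closest to the query.
  def nearest(collection: Seq[LabeledEmbedding], query: Array[Float], k: Int): Seq[String] =
    collection.sortBy(e => sqDist(e.vector, query)).take(k).map(_.label)

  // Zero-shot classification: a simple majority vote among the k nearest labels.
  def classify(collection: Seq[LabeledEmbedding], query: Array[Float], k: Int): String =
    nearest(collection, query, k).groupBy(identity).maxBy(_._2.size)._1
}

Swap classify for a "return the items behind the nearest vectors" step and the same machinery works as a recommender.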

Challenges we faced

This sounds great, but creating and updating a collection requires state and memory utilization beyond the scope of a simple functional algorithm. Creating a tool that’s capable of handling 100 data points is relatively straightforward; creating one that can also handle 10 million data points is not.

Challenge 1 – Heap space restrictions

We very quickly ran into hardware constraints. Even on the most powerful AWS instances, we quickly hit out-of-memory JVM exceptions that severely limited the amount of data our vector collection could ever hold. We deduced that for a pure in-memory algorithm the maximum collection size was just over 50k samples (for a document classification task). This isn’t completely unusable, as some problem spaces are comfortable with that number of samples, but for very large, “big data” scale problems it is definitely not enough.

Challenge 2 – Lookup performance issues

The next big issue we discovered was that even if we could somehow overcome this crippling limit on the number of vectors we can work with at once, KNN traversal is at best O(n), i.e. linear in the size of the search space. For any reasonably sized model the performance was dreadful, and it was a definite blocker for production-level usability.

Challenge 3 – Preserving accuracy

There are many techniques that can address both #1 and #2, but constructing a mechanism that preserves the accuracy of a raw in-memory M*N multiplication can be really tricky. Finding the optimal trade-off between performance and accuracy is a delicate balance, and must be handled with care.


The Solution, and what we built

As most algorithms are simple functions, we don’t generally have to contend with complex persistent state management. To manage state on Algorithmia, we default to a file-based storage system that we call our Data API, which uses S3 as a backend. This can be somewhat cumbersome for a project as complex as this one; however, with a few modifications to an indexing algorithm we were able to come up with a relatively painless solution.

The index creation and tree-searching algorithm we used was developed from a combination of standard Nearest Neighbour Search and a modified K-D Tree algorithm for tree creation. The big difference between a K-D tree and a typical binary tree is its binning technique: at each branch, a split plane determines which dimension in K-space the data is split on.

function kdtree (list of points pointList, int depth)
{
    // Select axis based on depth so that axis cycles through all valid values
    var int axis := depth mod k;
        
    // Sort point list and choose median as pivot element
    select median by axis from pointList;
        
    // Create node and construct subtree
    node.location := median;
    node.leftChild := kdtree(points in pointList before median, depth+1);
    node.rightChild := kdtree(points in pointList after median, depth+1);
    return node;
}

Source: Wikipedia

This works great for low-dimensional data with independent variables; however, in some circumstances embedding vectors have “coupled” variables – or in simpler terms, a dimension’s “meaning” might be spread across more than one variable – as is the case for word embeddings.

Dimensionality Reduction – Locality Preserving Hashing

How do we combat this? Well thankfully we already have a pretty good way of reducing dimensionality – we can use the Locality Preserving Hashing algorithm.

Locality Preserving Hashing

This is a pretty neat technique: unlike Locality Sensitive Hashing, our hashing function is not random – in fact it’s barely even a hashing function at all! With this technique we were able to preserve the structure of the input and dramatically improve the accuracy of our k-d tree bucketing.
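
To give a flavour of what that looks like, here is a minimal sketch, with the caveat that the plane-selection step is elided: the split planes are assumed to have been derived from the data itself (for example, from directions that spread the data out well), which is precisely what separates this from the random projections of Locality Sensitive Hashing:

// Sketch of a locality-preserving hash: project a vector onto a set of
// data-derived split planes and keep only the sign of each projection.
// Nearby vectors end up with similar bit signatures, so the signature acts
// as a structure-preserving, lower-dimensional stand-in for the raw vector.
object LocalityPreservingHash {
  type Vec = Array[Float]

  private def dot(a: Vec, b: Vec): Float =
    a.zip(b).map { case (x, y) => x * y }.sum

  // One bit per split plane: which side of the plane does the vector fall on?
  def signature(vector: Vec, splitPlanes: Seq[Vec]): Seq[Boolean] =
    splitPlanes.map(plane => dot(vector, plane) >= 0f)

  // Hamming distance between signatures approximates closeness in the
  // original space, which is what makes the k-d tree bucketing more accurate.
  def hammingDistance(a: Seq[Boolean], b: Seq[Boolean]): Int =
    a.zip(b).count { case (x, y) => x != y }
}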

Tree Construction – KD Trees

We don’t want to construct the tree all the way down to individual points, otherwise we’d have to store every data point in memory at once. Instead, we get the construction algorithm to keep forming new branches until it reaches our ideal “bucket size”. Once that bucket size is reached, we save that bucket as a file and store a fingerprint for the bucket file so we can easily find it later.

// Recursively constructs a k-d tree by repeatedly splitting `dataset`
// into smaller and smaller lists of data.
//
// Differences from the vanilla implementation:
//  - the split point is determined by the average, not the median
//  - the split dimension (axis) for each node layer is chosen randomly
//  - once minBucketSize or maxDepth is reached, we stop splitting the data
//    and instead serialize the remaining branch.
  def buildTreeAverage(dataset: TreeBuilderMatrix, minBucketSize: Int, maxDepth: Int, depth: Int = 0): Validation[Exception, Node] = {
    val maxAxis = dataset.getMaxAxis
    val axis = Random.nextInt(maxAxis)
    val average = dataset.getAverageDist(axis)
    if( dataset.length > minBucketSize && depth < maxDepth) {
      val leftSide = dataset.getLeftByAverage(axis, average)
      val rightSide = dataset.getRightByAverage(axis, average)
      val left = buildTreeAverage(leftSide, minBucketSize, maxDepth, depth + 1).valueOr(t => return Failure(t))
      val right = buildTreeAverage(rightSide, minBucketSize, maxDepth, depth + 1).valueOr(t => return Failure(t))
      Success(TrunkNode(average, axis, left, right))
    }
    else {
      val leftSide = dataset.getLeftByAverage(axis, average)
      val rightSide = dataset.getRightByAverage(axis, average)
      val rightElms = rightSide.extractEmbeddings()
      val leftElms = leftSide.extractEmbeddings()
      Success(RealLeafNode(average, axis, leftElms, rightElms))
    }
  }

You might be able to see the resemblance to the pseudocode above; the big takeaways are the early exit, and the fact that we split on the average, not the median.

After we finish building this “shallow” tree, we serialize our buckets, throw them into the Data API, and save the URL so we can find them again later.
Let’s take a look at the code that describes this pruning mechanism:

sealed trait Node {
  val distance: Float
  val axis: Int
}

case class TrunkNode(distance: Float, axis: Int = 0, left: Node, right: Node) extends Node
case class RealLeafNode(distance: Float, axis: Int, left: List[Embedding], right: List[Embedding]) extends Node
case class RefLeafNode(distance: Float = 10000f, axis: Int = 0, left: String = "", right: String = "") extends Node

// Recursively searches through a k-d tree for "RealLeafNode"'s and replaces them with a reference.
// Uses putBucket to serialize and upload the bucket file.
  def pruneLeaves(node: Node, namespace: String, localBuffer: String,
  client: AlgorithmiaClient): Validation[Exception, Node] = node match {
    case trunk: TrunkNode => {
      val newLeft: Node = pruneLeaves(trunk.left, namespace, localBuffer, client).valueOr(t => return Failure(t))
      val newRight: Node = pruneLeaves(trunk.right, namespace, localBuffer, client).valueOr(t => return Failure(t))
      Success(TrunkNode(trunk.distance, trunk.axis, newLeft, newRight))
    }
    case leaf: RealLeafNode => {
      val leftPayload: List[Embedding] = leaf.left
      val rightPayload: List[Embedding] = leaf.right
      val leftPrune: DataFile = putBucket(leftPayload, namespace, localBuffer, client).valueOr(x => return Failure(x))
      val rightPrune: DataFile = putBucket(rightPayload, namespace, localBuffer, client).valueOr(x => return Failure(x))
      Success(RefLeafNode(leaf.distance, leaf.axis, leftPrune.getName, rightPrune.getName))
    }
    case leaf: RefLeafNode => Success(leaf)
  }

// Serializes the bucket to a binary file, then uploads it to the algorithmia Data API.
  private def putBucket(payload: List[Embedding], namespace: String, localBuffer: String, client: AlgorithmiaClient): Validation[Exception, DataFile] = {
    val savedBucket = Utils.writeToTempBinary(payload, localBuffer)
      .getOrElse(return Failure(new AlgorithmException("E5443: unable to write bucket to file.")))
    AlgoRelated.putBucket(savedBucket, namespace, client)
  }

You can see that we replace the RealLeafNode terminal nodes (which contain physical buckets) with RefLeafNode nodes; in the process we upload the physical buckets to the Data API. Each reference node contains the remote file path of the bucket we just uploaded, so it can be easily referenced and retrieved later!
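
The read path at inference time is simply the mirror image of putBucket. A sketch of what that counterpart might look like follows; Utils.readFromBinary and AlgoRelated.getBucket are hypothetical names standing in for the actual download-and-deserialize helpers:

// Hypothetical counterpart to putBucket: download a bucket file referenced by a
// RefLeafNode from the Data API and deserialize it back into embeddings.
  private def getBucket(bucketName: String, namespace: String, localBuffer: String,
  client: AlgorithmiaClient): Validation[Exception, List[Embedding]] = {
    // Fetch the remote bucket file by name.
    val remoteFile: DataFile = AlgoRelated.getBucket(bucketName, namespace, client)
      .valueOr(t => return Failure(t))
    // Deserialize the binary payload back into a list of embeddings.
    val embeddings: List[Embedding] = Utils.readFromBinary(remoteFile, localBuffer)
      .getOrElse(return Failure(new AlgorithmException("E5444: unable to read bucket from file.")))
    Success(embeddings)
  }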

Index creation

Our tree functions much like a database index; however, our lookup is entirely approximate – there is no guarantee that we will find the single best element, but we can get very close, and adding more split planes rapidly increases the accuracy of the search.

When we get an inference request, we just traverse down that same tree and find the nearest bucket. We can’t just let the tree sit in memory, however: Algorithmia requires state to be managed outside of the system, in the Data API. So we serialize this index and store it in the Data API before the training process finishes.

sealed trait Index {
  val splittingPlanes: List[Embedding]
  val base: Node
  val dataFiles: List[String]
  val keepPrivate: Boolean
}

case class NormalIndex(splittingPlanes: List[Embedding], base: Node, dataFiles: List[String], nbowSize: Int, keepPrivate: Boolean) extends Index
case class UnitaryIndex(splittingPlanes: List[Embedding], base: Node, dataFiles: List[String], keepPrivate: Boolean) extends Index
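
The index itself gets the same treatment as the buckets. A sketch of that last step of training, mirroring the putBucket pattern above (putIndex here, and reusing writeToTempBinary for an Index, are assumptions for illustration):

// Hypothetical sketch: serialize the finished index (split planes, pruned tree
// and bucket references) and upload it to the Data API at the end of training.
  private def putIndex(index: Index, namespace: String, localBuffer: String,
  client: AlgorithmiaClient): Validation[Exception, DataFile] = {
    val savedIndex = Utils.writeToTempBinary(index, localBuffer)
      .getOrElse(return Failure(new AlgorithmException("E5445: unable to write index to file.")))
    AlgoRelated.putIndex(savedIndex, namespace, client)
  }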

When doing any kind of inference or prediction operation, the index files will be the first things downloaded. We then check which buckets contain the best embeddings for our sample by traversing the k-d tree contained within each index file, and download just the buckets that are closest to our sample.
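
A rough sketch of that traversal might look like the following, where valueAt is an assumed accessor giving the query embedding’s value along a given axis; we descend the pruned tree until we hit a RefLeafNode, which tells us exactly which bucket files to download:

// Walk the pruned k-d tree for a query embedding and return the names of the
// bucket files worth downloading for the nearest-neighbour search.
  def findBucketNames(node: Node, query: Embedding): List[String] = node match {
    case trunk: TrunkNode =>
      // Follow whichever side of the split plane the query falls on.
      if (query.valueAt(trunk.axis) <= trunk.distance) findBucketNames(trunk.left, query)
      else findBucketNames(trunk.right, query)
    case leaf: RefLeafNode =>
      // Both halves of the leaf are candidates; download both buckets and let
      // k-nearest neighbours pick the best matches from their contents.
      List(leaf.left, leaf.right)
    case _: RealLeafNode =>
      // Should not appear after pruning, but handle it gracefully.
      Nil
  }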

Lets see what the whole construction process looks like in flow chart form:

construction process


Inference

Our inference technique relies heavily on our indexing tools and k-nearest neighbours. For a more detailed view, take a look at the inference flow chart below:

inference process


Our Results

That’s enough code and flow charts; let’s see how well our modifications performed in the wild.

We constructed a dataset of research paper abstracts scraped from www.arxiv.org and labeled them according to their topics. We then threw them into the Document Classifier – an example use of the algorithm described here – and went hunting for performance.

Old Model

Our previous version did not have indexing capabilities and was capped at 50k samples; within that restriction it reached an accuracy of 80.66%. Pretty good. At the time of testing, it was more accurate than Facebook’s FastText, which had an accuracy of 78.36%.

The compute time per document, however, was far too high: averaging close to 5 seconds per document, evaluating the entire test set took more than 10 hours to complete! This was way too long, and really restricted the types of projects the algorithm could be used for.

New Model

Contrast that with the table below, which uses the full 100k-document dataset:

k    accuracy   avg compute per document (ms)
10   82.42%     97.8
5    79.27%     57.92
3    76.71%     42.69
1    62.79%     34.99

where k is the maximum number of buckets to search through during inference.

Note: compute times are averaged over the entire test set, which was processed in batches of more than 100 documents. Your averages might be different.


In Conclusion

Using some new data management techniques alongside good old-fashioned computer science, we were able to dramatically improve the performance of our algorithm, and even to generalize it so that it can tackle other types of unsupervised problems in the future!

Did we spark your curiosity? Head on over to Algorithmia, check out the Document Classifier algorithm, and see just what kinds of things it can do.

Algorithm Engineer at Algorithmia, empowering users by building state-of-the-art, production-ready algorithms to solve their unique challenges
