Extend Spark ML for your own model/transformer types

How to use the wordcount example as a starting point (and you thought you’d escape the wordcount example).

By Holden Karau
February 2, 2017
While Spark ML pipelines have a wide variety of algorithms, you may find yourself wanting additional functionality without having to leave the pipeline model. In Spark MLlib, this isn’t much of a problem—you can manually implement your algorithm with RDD transformations and keep going from there. For Spark ML pipelines, the same approach can work, but we lose some of the nicely integrated properties of the pipeline, including the ability to automatically run meta-algorithms, such as cross-validation parameter search. In this article, you will learn how to extend the Spark ML pipeline model using the standard wordcount example as a starting point (one can never really escape the intro to big data wordcount example).

To add your own algorithm to a Spark pipeline, you need to implement either Estimator or Transformer, both in org.apache.spark.ml and both extending the base PipelineStage. For algorithms that don’t require training, you can implement Transformer, and for algorithms with training you can implement Estimator. Note that training is not limited to complicated machine learning models; even MinMaxScaler requires training to determine the range of each feature. Any stage that needs training must be constructed as an Estimator rather than a Transformer.
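To see why a scaler has to be an Estimator, here is a plain-Scala sketch (no Spark involved) of the same split. fit makes a pass over the training data to learn the range and returns a model, and the model's transform only applies the learned numbers. MinMaxSketch and MinMaxSketchModel are hypothetical names that mirror the Estimator/Model roles; they are not Spark classes.

```scala
// Hypothetical "model": holds what fit learned and applies it to new values.
final case class MinMaxSketchModel(min: Double, max: Double) {
  // Rescale to [0, 1] using the learned range; a constant training column
  // maps everything to 0.5 (an assumption of this sketch).
  def transform(values: Seq[Double]): Seq[Double] = {
    val range = max - min
    values.map(v => if (range == 0.0) 0.5 else (v - min) / range)
  }
}

// Hypothetical "estimator": fitting requires a pass over the training data.
object MinMaxSketch {
  def fit(training: Seq[Double]): MinMaxSketchModel = {
    require(training.nonEmpty, "cannot fit on an empty dataset")
    MinMaxSketchModel(training.min, training.max)
  }
}
```

For example, `MinMaxSketch.fit(Seq(2.0, 4.0, 6.0))` learns the range 2.0 to 6.0, after which 4.0 maps to 0.5. A value outside the training range, such as 8.0, maps to 1.5: the model only knows what it saw during fit.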

Using PipelineStage directly does not work, since Pipeline.fit checks the type of each stage and assumes every stage is either an Estimator or a Transformer; any other stage is rejected with an IllegalArgumentException.

In addition to the obvious transform or fit function, all pipeline stages need to provide transformSchema and copy (or extend a class that provides them for you). copy is used to make a copy of the current stage with any newly specified params merged in, and can usually simply call defaultCopy. The exception is a class with special constructor considerations: defaultCopy builds the copy by calling a constructor that takes only the uid, so stages whose constructors take other arguments need to construct the copy themselves.
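Spark resolves a stage’s param values in a documented order (see Params.extractParamMap): default values first, then values the user set explicitly, then any extra values such as the ParamMap passed to copy, with later entries winning. The plain-Scala sketch below models that precedence with ordinary Maps; ParamMerging and SketchParamMap are hypothetical stand-ins for illustration, not Spark’s ParamMap API.

```scala
object ParamMerging {
  // Hypothetical stand-in for a ParamMap: param name -> value
  type SketchParamMap = Map[String, String]

  // Map ++ keeps the right-hand value on key conflicts, which gives the
  // precedence defaults < explicitly set < extra.
  def effective(
      defaults: SketchParamMap,
      explicitlySet: SketchParamMap,
      extra: SketchParamMap = Map.empty): SketchParamMap =
    defaults ++ explicitlySet ++ extra
}
```

So a copy made with an extra outputCol keeps the inputCol the user set, but overrides the output column.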

The start of a pipeline stage, including the copy delegation, is shown below. transformSchema must produce the expected output schema of your pipeline stage, based on any parameters set and the input schema. Most pipeline stages simply add new fields; very few drop previous fields, in case they are needed downstream. This can sometimes result in records containing more data than is required, negatively impacting performance; if you find this is a problem in your pipeline, you can create your own stage to drop unnecessary fields.

class HardCodedWordCountStage(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("hardcodedwordcount"))

  def copy(extra: ParamMap): HardCodedWordCountStage = {
    defaultCopy(extra)
  }

In addition to producing the output schema, the transformSchema function should validate that the input schema is suitable for the stage (e.g., the input column is of the expected type).

This is also where you should perform validation on your stage’s parameters.

A simple transformSchema for string inputs and an integer output, with hard-coded input and output columns, is illustrated as follows.

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex("happy_pandas")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(
        s"Input type ${field.dataType} did not match expected type StringType")
    }
    // Add the return field
    schema.add(StructField("happy_panda_counts", IntegerType, false))
  }

Algorithms that do not require training can be implemented very simply using the Transformer interface. Since this is the simplest pipeline stage, you can start by implementing a simple transformer, which counts the number of words in the input column.

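  def transform(df: Dataset[_]): DataFrame = {
    // Count the space-separated tokens in each input string
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"),
      wordcount(df.col("happy_pandas")).as("happy_panda_counts"))
  }
}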
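The UDF above splits on a single space character. As a result, two consecutive spaces produce an empty token that gets counted, an empty string counts as one word, and a null value makes the UDF throw a NullPointerException. Here is a plain-Scala sketch of a more defensive counting function, assuming you want whitespace runs collapsed and nulls counted as zero words. WordCounting and countWords are hypothetical names.

```scala
object WordCounting {
  // Same logic as the UDF body above: split on a single space.
  def naiveCount(in: String): Int = in.split(" ").size

  // Hypothetical defensive variant: null or blank input counts as zero words,
  // leading/trailing whitespace is ignored and whitespace runs are collapsed.
  def countWords(in: String): Int =
    Option(in).map(_.trim).filter(_.nonEmpty) match {
      case Some(s) => s.split("\\s+").length
      case None    => 0
    }
}
```

Inside transform you could then use `udf { in: String => WordCounting.countWords(in) }` in place of the inline lambda.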

To get the most out of the pipeline interface, you will want to make your pipeline stage configurable using the params interface.

While the params interface is public, sadly the shared params used throughout Spark’s own stages are private, so you will end up with some amount of code duplication. In addition to allowing users to specify values, parameters can also contain some basic validation logic (e.g., the regularization parameter must be set to a non-negative value). The two most common parameters are input column and output column, which you can add to your model relatively simply.

In addition to string params, any other type can be used, including lists of strings for things like stop words, and doubles for things like the regularization parameter.
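In Spark itself, a validator is passed as the last argument of the Param constructor and is simply a predicate on the value. For example, a non-negative regularization parameter can be declared as `new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))`, and setting an invalid value throws an IllegalArgumentException. The plain-Scala sketch below (no Spark) mimics that behavior so the idea can be tried in isolation. SketchParam, SketchValidators and SketchParams are hypothetical stand-ins, not Spark classes.

```scala
// Hypothetical stand-in for Spark's Param: a name, a doc string and a validator.
final case class SketchParam[T](name: String, doc: String, isValid: T => Boolean) {
  // Throws IllegalArgumentException (via require) for values the validator rejects
  def validate(value: T): T = {
    require(isValid(value), s"parameter $name given invalid value $value")
    value
  }
}

// Validators are just predicates, in the same spirit as Spark's ParamValidators.
object SketchValidators {
  def gtEq(lowerBound: Double): Double => Boolean = _ >= lowerBound
  def nonEmptyArray: Array[String] => Boolean = _.nonEmpty
}

object SketchParams {
  val regParam = SketchParam[Double](
    "regParam", "regularization parameter (>= 0)", SketchValidators.gtEq(0.0))
  val stopWords = SketchParam[Array[String]](
    "stopWords", "words to filter out", SketchValidators.nonEmptyArray)
}
```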

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class ConfigurableWordCount(override val uid: String) extends Transformer {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("configurablewordcount"))

  def copy(extra: ParamMap): ConfigurableWordCount = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(
        s"Input type ${field.dataType} did not match expected type StringType")
    }
    // Add the return field
    schema.add(StructField($(outputCol), IntegerType, false))
  }

  def transform(df: Dataset[_]): DataFrame = {
    // Validate the params and input schema before doing any work
    transformSchema(df.schema)
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"), wordcount(df.col($(inputCol))).as($(outputCol)))
  }
}

Algorithms that do require training can be implemented using the Estimator interface, although for many algorithms the org.apache.spark.ml.Predictor or org.apache.spark.ml.classification.Classifier helper classes are easier to implement. The primary difference between the Estimator and Transformer interfaces is that rather than directly expressing your transformation on the input, you first have a training step in the form of a fit function (or a train function, if you extend Predictor), which returns a Model that performs the transformation. A string indexer is one of the simplest estimators you can implement, and while it’s already available in Spark (as StringIndexer), it is still a good illustration of how to use the estimator interface.

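import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Params shared by the estimator and the model it produces
trait SimpleIndexerParams extends Params {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")
}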

class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  def setInputCol(value: String) = set(inputCol, value)

  def setOutputCol(value: String) = set(outputCol, value)

  def this() = this(Identifiable.randomUID("simpleindexer"))

  override def copy(extra: ParamMap): SimpleIndexer = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(
        s"Input type ${field.dataType} did not match expected type StringType")
    }
    // Add the return field; indices are Doubles, matching the model's output
    schema.add(StructField($(outputCol), DoubleType, false))
  }

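  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    transformSchema(dataset.schema)
    // Collect the distinct (non-null) words to the driver, sorted so the
    // word-to-index assignment is deterministic across runs
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct
      .collect().filter(_ != null).sorted
    // Copy inputCol/outputCol onto the model so its transform can read them
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }
}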

class SimpleIndexerModel(
  override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {

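  override def copy(extra: ParamMap): SimpleIndexerModel = {
    // defaultCopy needs a uid-only constructor, so build the copy explicitly
    copyValues(new SimpleIndexerModel(uid, words), extra).setParent(parent)
  }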

  private val labelToIndex: Map[String, Double] = words.zipWithIndex.
    map{case (x, y) => (x, y.toDouble)}.toMap

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(
        s"Input type ${field.dataType} did not match expected type StringType")
    }
    // Add the return field; labelToIndex produces Doubles
    schema.add(StructField($(outputCol), DoubleType, false))
  }

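  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema)
    // Copy the mapping into a local val so the UDF closure captures only the map.
    // Words not seen during fit make this lookup throw.
    val mapping = labelToIndex
    val indexer = udf { label: String => mapping(label) }
    dataset.select(col("*"), indexer(dataset($(inputCol))).as($(outputCol)))
  }
}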
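The model looks words up with a plain Map apply, so a word that was not present during fit makes the UDF throw a NoSuchElementException (Spark’s own StringIndexer offers a handleInvalid param for this case). Here is a plain-Scala sketch of the fit/transform logic, assuming you would rather send unseen words to one extra reserved index. IndexerLogic, buildIndex and lookup are hypothetical helpers, not part of the code above.

```scala
object IndexerLogic {
  // "Fit": give each distinct word an index, sorting first so that the
  // assignment does not depend on input order.
  def buildIndex(words: Seq[String]): Map[String, Double] =
    words.distinct.sorted.zipWithIndex.map { case (w, i) => (w, i.toDouble) }.toMap

  // "Transform": look a word up, mapping unseen words to the next free index.
  def lookup(index: Map[String, Double], word: String): Double =
    index.getOrElse(word, index.size.toDouble)
}
```

For example, indexing Seq("panda", "bear", "panda", "koala") gives bear -> 0.0, koala -> 1.0 and panda -> 2.0, and an unseen "sloth" maps to 3.0.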

If you are implementing an iterative algorithm, you may wish to consider caching the input data automatically if it’s not already cached, or allow the user to specify a persistence level.

The Predictor interface replaces the two most common parameters (input and output columns) with a label column, a features column, and a prediction column, and automatically handles the schema transformation for us.

The Classifier interface does much the same, except it also adds a rawPredictionCol param and provides tools to detect the number of classes (getNumClasses) and to convert the input DataFrame to an RDD of LabeledPoints (extractLabeledPoints), making it easier to wrap legacy MLlib classification algorithms.

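If you are implementing a regression or clustering algorithm, there is no public base set of interfaces to use, so you will need to use the generic Estimator interface. The following simple Bernoulli naive Bayes classifier, by contrast, builds on the Classifier interface.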

import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.linalg.{DenseMatrix, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, count}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel

// Label/feature-index pair emitted for every non-zero feature
case class LabeledToken(label: Double, index: Int)

// Simple Bernoulli Naive Bayes classifier - no sanity checks for brevity
// Example only - not for production use.
class SimpleNaiveBayes(val uid: String)
    extends Classifier[Vector, SimpleNaiveBayes, SimpleNaiveBayesModel] {

  def this() = this(Identifiable.randomUID("simple-naive-bayes"))

  override def train(ds: Dataset[_]): SimpleNaiveBayesModel = {
    import ds.sparkSession.implicits._
    // We scan the input several times, so cache it unless the caller already did
    val handlePersistence = ds.storageLevel == StorageLevel.NONE
    if (handlePersistence) ds.persist(StorageLevel.MEMORY_AND_DISK)
    // Note: you can use getNumClasses and extractLabeledPoints to get an RDD instead
    // Using the RDD approach is common when integrating with legacy machine learning code
    // or iterative algorithms which can create large query plans.
    // Here we use Datasets since neither of those apply.

    // Compute the number of documents
    val numDocs = ds.count
    // Get the number of classes.
    // Note this estimator assumes labels run from 0 to numClasses - 1
    val numClasses = getNumClasses(ds)
    // Get the number of features by peeking at the first row
    val numFeatures: Int = ds.select(col($(featuresCol))).head.getAs[Vector](0).size
    // Determine the number of records for each class
    val groupedByLabel = ds.select(col($(labelCol)).as[Double]).groupByKey(x => x)
    val classCounts: Map[Double, Long] = groupedByLabel.agg(count("*").as[Long])
      .collect().toMap
    // Select the labels and features so we can more easily map over them.
    // Note: we do this as a DataFrame using the untyped API because the Vector
    // UDT is no longer public.
    val df = ds.select(col($(labelCol)).cast(DoubleType), col($(featuresCol)))
    // Figure out the non-zero frequency of each feature for each label and
    // output label index pairs using a case class to make it easier to work with.
    val labelCounts: Dataset[LabeledToken] = df.flatMap {
      case Row(label: Double, features: Vector) =>
        features.toArray.zipWithIndex.collect {
          case (v, idx) if v != 0.0 => LabeledToken(label, idx)
        }
    }
    // Use the typed Dataset aggregation API to count the number of non-zero
    // features for each label-feature index.
    val aggregatedCounts: Array[((Double, Int), Long)] = labelCounts
      .groupByKey(x => (x.label, x.index))
      .agg(count("*").as[Long])
      .collect()

    // log of number of documents for this label + 2.0 (smoothing)
    def thetaLogDenom(label: Double): Double =
      math.log(classCounts.getOrElse(label, 0L) + 2.0)
    // Start every entry at the smoothed value for a zero count, log(0 + 1) - denom,
    // so features never seen with a class don't get probability 1
    val theta = Array.tabulate(numClasses, numFeatures) { (c, _) => -thetaLogDenom(c.toDouble) }

    // Compute the denominator for the general priors
    val piLogDenom = math.log(numDocs + numClasses)
    // Compute the (add-one smoothed) priors for each class, in class-index order
    val pi = Array.tabulate(numClasses) { c =>
      math.log(classCounts.getOrElse(c.toDouble, 0L) + 1.0) - piLogDenom }

    // For each label/feature update the probabilities
    aggregatedCounts.foreach { case ((label, featureIndex), count) =>
      theta(label.toInt)(featureIndex) = math.log(count + 1.0) - thetaLogDenom(label)
    }
    // Unpersist now that we are done computing everything
    if (handlePersistence) ds.unpersist()
    // Construct a model
    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
      new DenseMatrix(numClasses, numFeatures, theta.flatten, true))
  }

  override def copy(extra: ParamMap): SimpleNaiveBayes = {
    defaultCopy(extra)
  }
}
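For reference, the quantities the classifier estimates can be written as follows, assuming add-one (Laplace) smoothing and binary features x_j in {0, 1}. Here N is the number of documents, C the number of classes, N_c the number of documents with label c, and N_{c,j} the number of those documents in which feature j is non-zero. The rearranged last line explains why the model below precomputes theta minus negTheta and the row sums of negTheta: prediction then needs only one matrix-vector product plus two vector additions.

```latex
\pi_c = \log(N_c + 1) - \log(N + C)
\qquad
\theta_{c,j} = \log(N_{c,j} + 1) - \log(N_c + 2)

\bar{\theta}_{c,j} = \log\left(1 - e^{\theta_{c,j}}\right)

\mathrm{raw}_c(x)
  = \pi_c + \sum_{j} \left[ x_j\,\theta_{c,j} + (1 - x_j)\,\bar{\theta}_{c,j} \right]
  = \pi_c + \sum_{j} x_j \left( \theta_{c,j} - \bar{\theta}_{c,j} \right) + \sum_{j} \bar{\theta}_{c,j}
```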

// Simplified Naive Bayes Model
case class SimpleNaiveBayesModel(
  override val uid: String,
  override val numClasses: Int,
  override val numFeatures: Int,
  val pi: Vector,
  val theta: DenseMatrix) extends
    ClassificationModel[Vector, SimpleNaiveBayesModel] {

  override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
    // defaultCopy needs a uid-only constructor, so build the copy explicitly
    val copied = new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta)
    copyValues(copied, extra).setParent(parent)
  }

  // We have to do some tricks here because we are using Spark's
  // Vector/DenseMatrix calculations - but for your own model don't feel
  // limited to Spark's native ones.
  // negTheta holds log(1 - P(feature present | class))
  val negThetaArray = theta.values.map(v => math.log(1.0 - math.exp(v)))
  val negTheta = new DenseMatrix(numClasses, numFeatures, negThetaArray, true)
  val thetaMinusNegThetaArray = theta.values.zip(negThetaArray)
    .map{case (v, nv) => v - nv}
  val thetaMinusNegTheta = new DenseMatrix(
    numClasses, numFeatures, thetaMinusNegThetaArray, true)
  val onesVec = Vectors.dense(Array.fill(theta.numCols)(1.0))
  val negThetaSum: Array[Double] = negTheta.multiply(onesVec).toArray

  // Here is the prediction functionality you need to implement - for ClassificationModels
  // transform automatically wraps this - but if you might benefit from broadcasting your model or
  // other optimizations you can also override transform.
  def predictRaw(features: Vector): Vector = {
    // Toy implementation - ideally use BLAS or similar for summing
    // the three vectors, but that functionality isn't exposed.
    Vectors.dense(thetaMinusNegTheta.multiply(features).toArray.zip(pi.toArray)
      .map{case (x, y) => x + y}.zip(negThetaSum).map{case (x, y) => x + y})
  }
}
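To check the algebra behind predictRaw without a cluster, here is a plain-Scala sketch (no Spark) that fits the same kind of smoothed Bernoulli model on in-memory arrays. It then computes the per-class scores two ways: directly, and in the rearranged form that predictRaw uses. BernoulliNBSketch and its members are hypothetical names for illustration, and the add-one smoothing of the priors is an assumption of this sketch.

```scala
object BernoulliNBSketch {
  final case class Fitted(pi: Array[Double], theta: Array[Array[Double]])

  // Smoothed log priors and log P(feature present | class) from binary features
  def fit(labels: Array[Int], features: Array[Array[Double]], numClasses: Int): Fitted = {
    val numDocs = labels.length
    val numFeatures = features.head.length
    val classCounts = Array.tabulate(numClasses)(c => labels.count(_ == c))
    val pi = classCounts.map(n => math.log(n + 1.0) - math.log(numDocs + numClasses.toDouble))
    val theta = Array.tabulate(numClasses, numFeatures) { (c, j) =>
      val nonZero = labels.indices.count(i => labels(i) == c && features(i)(j) != 0.0)
      math.log(nonZero + 1.0) - math.log(classCounts(c) + 2.0)
    }
    Fitted(pi, theta)
  }

  // Direct form: pi_c + sum_j [x_j * theta_cj + (1 - x_j) * log(1 - exp(theta_cj))]
  def rawDirect(m: Fitted, x: Array[Double]): Array[Double] =
    m.pi.indices.map { c =>
      m.pi(c) + x.indices.map { j =>
        val t = m.theta(c)(j)
        x(j) * t + (1 - x(j)) * math.log(1.0 - math.exp(t))
      }.sum
    }.toArray

  // Rearranged form used by predictRaw: (theta - negTheta) . x + pi + sum(negTheta)
  def rawRearranged(m: Fitted, x: Array[Double]): Array[Double] =
    m.pi.indices.map { c =>
      val negTheta = m.theta(c).map(t => math.log(1.0 - math.exp(t)))
      val dot = x.indices.map(j => (m.theta(c)(j) - negTheta(j)) * x(j)).sum
      dot + m.pi(c) + negTheta.sum
    }.toArray

  // Predicted class = index of the highest raw score
  def predict(m: Fitted, x: Array[Double]): Int = {
    val raw = rawRearranged(m, x)
    raw.indices.maxBy(i => raw(i))
  }
}
```

Consider four two-feature documents where class 0 mostly has the first feature set and class 1 the second. After fitting on them, `predict(model, Array(1.0, 0.0))` returns 0, and the two score forms agree to within floating-point error.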

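If you simply need to modify an existing algorithm, you can extend it; when the pieces you need are package-private, you can do so by pretending to be in the org.apache.spark project (declaring your class inside an org.apache.spark package).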

Now you know how to extend Spark’s ML Pipeline API with your own stages. If you get lost, a good reference is the algorithms inside of Spark itself: while they do sometimes use internal APIs, for the most part they implement public interfaces in the same way that you will want to.

