Feature Caching in Redis

Created:
Category: Data Science
Tags: Tools Scala Spark

Hello there blog, it has been too long. I've been in America (at Disrupt and visiting our San Diego office) and worked on a bunch of projects in the meantime, but I want to share some useful info on putting preprocessed machine-learning features from Spark into Redis. I am still experimenting with different solutions, but here are some options I'm considering.

Say there is an ML pipeline that needs to go into production, and a bunch of the feature-processing data can be prepared beforehand. It would be best to put these features in some form of cache, close to where the inference will take place. An in-memory hash map is one solution, but it has to be repopulated on every run. Another common solution is to save the features in a Redis cache.

#1

Luckily Redis has helped a bit here by providing the spark-redis connector. It lets us write a DataFrame directly into Redis. Given a Redis instance running on localhost at port 6379:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder
    .appName("Uploader number 1")
    .master("local")
    .config("spark.redis.host", "localhost")
    .config("spark.redis.port", "6379")
    .getOrCreate()
import spark.implicits._

dfToUpload.write
    .format("org.apache.spark.sql.redis")
    .option("table", "tablename")
    .option("key.column", "nameKeyColumn")
    .save()

This is by far the easiest method. Each row is stored as a Redis hash under the key tablename:<key value>, with one field per column, so the receiving side can still fetch individual columns of a row. The hgetall command gets all the features in one go.

(Using scala-redis)

import com.redis.RedisClient

// spark-redis prefixes the key with the table name
val keyVal = "tablename:aKeyName"
val r = new RedisClient("localhost", 6379)
val sredResp: Option[Map[String, String]] = r.hgetall[String, String](keyVal)
r.close()
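
The reply is a map of strings, so the inference side still has to turn it into numbers. A minimal sketch of that step, assuming three made-up feature names (age, avgBasket, visits) and that every feature is numeric; neither assumption comes from the pipeline above:

```scala
object HashToFeatures {
  // Hypothetical feature names; their order fixes the layout of the vector.
  val featureOrder: Seq[String] = Seq("age", "avgBasket", "visits")

  // Turn the string map returned by HGETALL into a dense vector.
  // Returns None if a field is missing or cannot be parsed as a number.
  def toVector(fields: Map[String, String]): Option[Array[Double]] = {
    val parsed: Seq[Option[Double]] = featureOrder.map { name =>
      fields.get(name).flatMap(s => scala.util.Try(s.toDouble).toOption)
    }
    if (parsed.forall(_.isDefined)) Some(parsed.flatten.toArray) else None
  }
}
```

With the response from above this would be used as sredResp.flatMap(HashToFeatures.toVector). Fields that are not listed in featureOrder are simply ignored.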

This works well, but every value comes back as a string, so there is a lot of converting between strings and the actual representation of the feature vector. In a lot of cases we want all the features of a row/sample at once. Something similar can be done with spark-redis by using its RDD support. The idea is to encode the whole sample with some form of serialization and store it under a single key, so it can be retrieved in one go. The downside of this approach is that spark-redis works with strings, so the serialized bytes have to be encoded as a string.

#2

For this solution I use Kryo (which Spark itself also uses) in combination with Twitter's chill as the serializer. I chose this because I want to use Scala / the JVM on the inferencing side.

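import java.util.Base64

import com.esotericsoftware.kryo.io.Output
import com.redislabs.provider.redis._   // adds toRedisKV to the SparkContext
import com.twitter.chill.ScalaKryoInstantiator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder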
    .appName("Uploader number 2")
    .master("local")
    .config("spark.redis.host", "localhost")
    .config("spark.redis.port", "6379")
    .getOrCreate()
import spark.implicits._

val dfToConvert = dfIn.as[FeatureCaseClass]

val keyedRDD = dfToConvert.rdd.keyBy(_.nameKeyColumn).mapPartitions(rows => {
    // Kryo is neither serializable nor thread-safe, so build one per partition
    // instead of capturing a driver-side instance in the closure
    val kryo = new ScalaKryoInstantiator().setRegistrationRequired(false).newKryo()
    val encoder = Base64.getEncoder
    rows.map(tup => {
        val output = new Output(512, -1)    // Guess the size, but allow it to grow
        kryo.writeObject(output, tup._2)
        // toBytes returns only the bytes written; getBuffer would include the unused tail
        tup._1 -> encoder.encodeToString(output.toBytes)
    })
})

val sc = spark.sparkContext
sc.toRedisKV(keyedRDD)

Now getting the data out (this time using Jedis and a shared library with the case-class definition).

import java.util.Base64

import com.esotericsoftware.kryo.io.Input
import com.twitter.chill.ScalaKryoInstantiator
import redis.clients.jedis.Jedis

val keyVal = "keyname"
val jedis = new Jedis("localhost")

val kryo = new ScalaKryoInstantiator().setRegistrationRequired(false).newKryo()
val decoder = Base64.getDecoder

// get returns null for a missing key, so wrap it before decoding
val dataBack: Option[datatype.FeatureCaseClass] =
    Option(jedis.get(keyVal)).map(jedResp => {
        val input = new Input(decoder.decode(jedResp))
        kryo.readObject(input, classOf[datatype.FeatureCaseClass])
    })
if (dataBack.isEmpty) {
    println("Couldn't find key")
    // etc.
}
jedis.close()

This bundles the features neatly together, but it is not very efficient: the Base64 step adds an encode and a decode per sample and makes the stored value about a third larger than the raw bytes.
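
To get a feel for the cost of the Base64 step, here is a small sketch that uses only the JDK. A plain ByteBuffer stands in for Kryo so that it runs without dependencies, and the 64-double feature vector is made up:

```scala
import java.nio.ByteBuffer
import java.util.Base64

object Base64Overhead {
  // Pack doubles into raw bytes, 8 bytes each (stand-in for the Kryo output).
  def pack(features: Array[Double]): Array[Byte] = {
    val buffer = ByteBuffer.allocate(features.length * 8)
    features.foreach(d => buffer.putDouble(d))
    buffer.array()
  }

  // Inverse of pack: read the doubles back in the same order.
  def unpack(bytes: Array[Byte]): Array[Double] = {
    val in = ByteBuffer.wrap(bytes)
    Array.fill(bytes.length / 8)(in.getDouble())
  }

  def main(args: Array[String]): Unit = {
    val features = Array.tabulate(64)(i => i * 0.5)   // hypothetical feature vector
    val raw = pack(features)
    // What method #2 stores in Redis, and what the reader has to undo again
    val encoded = Base64.getEncoder.encodeToString(raw)
    val restored = unpack(Base64.getDecoder.decode(encoded))
    println(s"raw bytes: ${raw.length}, Base64 chars: ${encoded.length}")
    println(s"round trip ok: ${restored.sameElements(features)}")
  }
}
```

For 64 doubles this is 512 raw bytes against 684 Base64 characters, the usual 4/3 ratio.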

It is also possible to set the spark-redis connector aside and push the samples with a client library (like Jedis), which can store raw bytes.

#3

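import com.esotericsoftware.kryo.io.Output
import com.twitter.chill.ScalaKryoInstantiator
import org.apache.spark.sql.SparkSession
import redis.clients.jedis.Jedis

val spark = SparkSession.builder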
    .appName("Uploader number 3")
    .master("local")
    .config("spark.redis.host", "localhost")
    .config("spark.redis.port", "6379")
    .getOrCreate()
import spark.implicits._

val dfToConvert = dfIn.as[FeatureCaseClass]

// foreachPartition is an action, so the upload runs without writing dummy output
dfToConvert.rdd.foreachPartition(rows => {
    val jedis = new Jedis("localhost")  // Every Spark partition gets its own client
    // Kryo is neither serializable nor thread-safe, so create it per partition as well
    val kryo = new ScalaKryoInstantiator().setRegistrationRequired(false).newKryo()
    try {
        rows.foreach(fcc => {
            val output = new Output(512, -1)    // Guess the size, but allow it to grow
            kryo.writeObject(output, fcc)
            // toBytes returns only the bytes written, not the whole backing buffer
            jedis.set(fcc.nameKeyColumn.getBytes, output.toBytes)
        })
    } finally {
        jedis.close()
    }
})
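
A per-partition client needs some care because of laziness: inside mapPartitions, Iterator.map does nothing until something consumes the result, so a client that is closed right after the map call is closed before any write happens. The sketch below shows this with plain Scala iterators. FakeClient is a made-up stand-in for a Redis client, not part of Jedis:

```scala
object LazyPartitionDemo {
  // Hypothetical stand-in for a Redis client; it refuses writes once closed.
  final class FakeClient {
    private var open = true
    val written = scala.collection.mutable.ArrayBuffer.empty[String]
    def set(key: String): String = {
      if (!open) throw new IllegalStateException("client already closed")
      written += key
      "OK"
    }
    def close(): Unit = { open = false }
  }

  // Broken: map is lazy, so nothing has been written yet when close() runs.
  def lazyUpload(keys: Iterator[String], client: FakeClient): Iterator[String] = {
    val replies = keys.map(k => client.set(k))
    client.close()
    replies
  }

  // Safe: foreach consumes the iterator before the client is closed.
  def eagerUpload(keys: Iterator[String], client: FakeClient): Unit = {
    try keys.foreach(k => client.set(k))
    finally client.close()
  }
}
```

Consuming the iterator returned by lazyUpload throws, because the writes only start after the client is gone. eagerUpload writes every key first and closes the client afterwards.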

And now for reading the output (with Jedis again).

import com.esotericsoftware.kryo.io.Input
import com.twitter.chill.ScalaKryoInstantiator
import redis.clients.jedis.Jedis

val keyVal = "keyname"
val jedis = new Jedis("localhost")

val kryo = new ScalaKryoInstantiator().setRegistrationRequired(false).newKryo()

// get returns null for a missing key, so wrap it before deserializing
val dataBack: Option[datatype.FeatureCaseClass] =
    Option(jedis.get(keyVal.getBytes)).map(jedResp =>
        kryo.readObject(new Input(jedResp), classOf[datatype.FeatureCaseClass]))
if (dataBack.isEmpty) {
    println("Couldn't find key")
    // etc.
}
jedis.close()