Pregunta Cómo rodar un estimador personalizado en PySpark mllib


Estoy tratando de construir una costumbre simple Estimator en PySpark MLlib. yo tengo aquí que es posible escribir un transformador personalizado, pero no estoy seguro de cómo hacerlo en un Estimator. Yo tampoco entiendo qué @keyword_only y por qué necesito tantos setters y getters. Scikit-learn parece tener un documento adecuado para modelos personalizados (mira aquí pero PySpark no.

Pseudo código de un modelo de ejemplo:

class NormalDeviation():
    def __init__(self, threshold = 3):
    def fit(x, y=None):
       self.model = {'mean': x.mean(), 'std': x.std()]
    def predict(x):
       return ((x-self.model['mean']) > self.threshold * self.model['std'])
    def decision_function(x): # does ml-lib support this?

10
2018-05-17 08:04


origen


Respuestas:


En general, no hay documentación porque, en cuanto a Spark 1.6 / 2.0, la mayoría de las API relacionadas no están destinadas a ser públicas. Debería cambiar en Spark 2.1.0 (ver SPARK-7146)

La API es relativamente compleja porque tiene que seguir convenciones específicas para hacer Transformer o Estimator compatible con Pipeline API. Algunos de estos métodos pueden ser necesarios para funciones como lectura y escritura o búsqueda de grillas. Otro, como keyword_only son simples ayudantes y no estrictamente necesarios.

Suponiendo que ha definido las siguientes combinaciones para el parámetro medio:

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)

parámetro de desviación estándar:

class HasStandardDeviation(Params):

    stddev = Param(Params._dummy(), "stddev", "stddev", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(stddev=value)

    def getStddev(self):
        return self.getOrDefault(self.stddev)

y umbral:

class HasCenteredThreshold(Params):

    centered_threshold = Param(Params._dummy(),
            "centered_threshold", "centered_threshold",
            typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centered_threshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centered_threshold)

podrías crear básico Estimator como sigue:

class NormalDeviation(Estimator, HasInputCol, 
        HasPredictionCol, HasCenteredThreshold):

    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return (NormalDeviationModel()
            .setInputCol(c)
            .setMean(mu)
            .setStddev(sigma)
            .setCenteredThreshold(self.getCenteredThreshold())
            .setPredictionCol(self.getPredictionCol()))

class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
        HasMean, HasStandardDeviation, HasCenteredThreshold):

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()

        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)

Finalmente podría usarse de la siguiente manera:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model  = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+

11
2018-05-17 14:51