wake-up-neo.com

Wie schwenke ich DataFrame?

Ich fange an, Spark DataFrames zu verwenden, und ich muss die Daten schwenken können, um mehrere Spalten aus einer Spalte mit mehreren Zeilen zu erstellen. In Scalding gibt es dafür eine integrierte Funktionalität. Ich glaube an Pandas in Python, aber ich kann nichts für den neuen Spark-Datenrahmen finden.

Ich gehe davon aus, dass ich eine benutzerdefinierte Funktion schreiben kann, die dies tut, aber ich bin mir nicht mal sicher, wie ich damit anfangen soll, zumal ich Spark bin. Ich weiß, wie das geht, mit eingebauten Funktionen oder mit Vorschlägen, wie man etwas in Scala schreibt. Es wird sehr geschätzt. 

38
J Calbreath

Wie bereits erwähnt by David Anderson Spark bietet seit Version 1.6 die pivot-Funktion. Die allgemeine Syntax sieht wie folgt aus:

df
  .groupBy(grouping_columns)
  .pivot(pivot_column, [values]) 
  .agg(aggregate_expressions)

Verwendungsbeispiele im nycflights13 - und csv-Format:

Python:

from pyspark.sql.functions import avg

flights = (sqlContext
    .read
    .format("csv")
    .options(inferSchema="true", header="true")
    .load("flights.csv")
    .na.drop())

flights.registerTempTable("flights")
sqlContext.cacheTable("flights")

gexprs = ("Origin", "dest", "carrier")
aggexpr = avg("arr_delay")

flights.count()
## 336776

%timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop

Scala:

val flights = sqlContext
  .read
  .format("csv")
  .options(Map("inferSchema" -> "true", "header" -> "true"))
  .load("flights.csv")

flights
  .groupBy($"Origin", $"dest", $"carrier")
  .pivot("hour")
  .agg(avg($"arr_delay"))

Java:

import static org.Apache.spark.sql.functions.*;
import org.Apache.spark.sql.*;

Dataset<Row> df = spark.read().format("csv")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("flights.csv");

df.groupBy(col("Origin"), col("dest"), col("carrier"))
        .pivot("hour")
        .agg(avg(col("arr_delay")));

R/SparkR:

library(magrittr)

flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

flights %>% 
  groupBy("Origin", "dest", "carrier") %>% 
  pivot("hour") %>% 
  agg(avg(column("arr_delay")))

R/sparklyr

library(dplyr)

flights <- spark_read_csv(sc, "flights", "flights.csv")

avg.arr.delay <- function(gdf) {
   expr <- invoke_static(
      sc,
      "org.Apache.spark.sql.functions",
      "avg",
      "arr_delay"
    )
    gdf %>% invoke("agg", expr, list())
}

flights %>% 
  sdf_pivot(Origin + dest + carrier ~  hour, fun.aggregate=avg.arr.delay)

SQL:

CREATE TEMPORARY VIEW flights 
USING csv 
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;

 SELECT * FROM (
   SELECT Origin, dest, carrier, arr_delay, hour FROM flights
 ) PIVOT (
   avg(arr_delay)
   FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
 );

Beispieldaten:

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","Origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00

Überlegungen zur Leistung:

Im Allgemeinen ist das Schwenken eine teure Operation. 

Verwandte Fragen:

52
zero323

Ich habe das überwunden, indem ich eine for-Schleife geschrieben habe, um eine SQL-Abfrage dynamisch zu erstellen. Sag ich habe:

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175

und ich will:

id  US  UK   Can
1   50  100  125
2   75  150  175

Ich kann eine Liste mit dem Wert erstellen, den ich schwenken möchte, und dann eine Zeichenfolge mit der SQL-Abfrage erstellen, die ich benötige.

val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1

var query = "select *, "
for (i <- 0 to numCountries-1) {
  query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
}
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"

myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)

Ich kann eine ähnliche Abfrage erstellen, um die Aggregation durchzuführen. Keine sehr elegante Lösung, aber sie funktioniert und ist flexibel für jede Liste von Werten, die beim Aufruf Ihres Codes auch als Argument übergeben werden kann.

13
J Calbreath

Der Spark-Dataframe-API wurde ein Pivot-Operator hinzugefügt, der Teil von Spark 1.6 ist.

Weitere Informationen finden Sie unter https://github.com/Apache/spark/pull/7841 .

9
David Anderson

Ich habe ein ähnliches Problem mit Dataframes mit den folgenden Schritten gelöst:

Erstellen Sie Spalten für alle Ihre Länder mit dem Wert "value":

import org.Apache.spark.sql.functions._
val countries = List("US", "UK", "Can")
val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) =>
  if(countryToCheck == countryInRow) value else 0
}
val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) }
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")

Ihr Datenrahmen 'dfWithCountries' sieht folgendermaßen aus:

+--+--+---+---+
|id|US| UK|Can|
+--+--+---+---+
| 1|50|  0|  0|
| 1| 0|100|  0|
| 1| 0|  0|125|
| 2|75|  0|  0|
| 2| 0|150|  0|
| 2| 0|  0|175|
+--+--+---+---+

Nun können Sie alle Werte für Ihr gewünschtes Ergebnis zusammenfassen:

dfWithCountries.groupBy("id").sum(countries: _*).show

Ergebnis:

+--+-------+-------+--------+
|id|SUM(US)|SUM(UK)|SUM(Can)|
+--+-------+-------+--------+
| 1|     50|    100|     125|
| 2|     75|    150|     175|
+--+-------+-------+--------+

Es ist jedoch keine sehr elegante Lösung. Ich musste eine Kette von Funktionen erstellen, um alle Spalten hinzuzufügen. Wenn ich viele Länder habe, werde ich meinen temporären Datensatz auf einen sehr großen Satz mit vielen Nullen erweitern.

5
Al M

Es gibt einfache und elegante Lösung.

scala> spark.sql("select * from k_tags limit 10").show()
+---------------+-------------+------+
|           imsi|         name| value|
+---------------+-------------+------+
|246021000000000|          age|    37|
|246021000000000|       gender|Female|
|246021000000000|         arpu|    22|
|246021000000000|   DeviceType| Phone|
|246021000000000|DataAllowance|   6GB|
+---------------+-------------+------+

scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show()
+---------------+-------------+----------+---+----+------+
|           imsi|DataAllowance|DeviceType|age|arpu|gender|
+---------------+-------------+----------+---+----+------+
|246021000000000|          6GB|     Phone| 37|  22|Female|
|246021000000001|          1GB|     Phone| 72|  10|  Male|
+---------------+-------------+----------+---+----+------+
1
Mantas

Es gibt viele Beispiele für Pivot-Operationen mit Dataset/Dataframe, aber ich konnte mit SQL nicht viele finden. Hier ist ein Beispiel, das für mich funktioniert hat.

create or replace temporary view faang 
as SELECT stock.date AS `Date`,
    stock.adj_close AS `Price`,
    stock.symbol as `Symbol` 
FROM stock  
WHERE (stock.symbol rlike '^(FB|AAPL|GOOG|AMZN)$') and year(date) > 2010;


SELECT * from faang 

PIVOT (max(price) for symbol in ('AAPL', 'FB', 'GOOG', 'AMZN')) order by date; 

0
abasar

Anfangs habe ich die Lösung von Al M übernommen. Später folgte der gleiche Gedanke und schrieb diese Funktion als Transponierfunktion um.

Diese Methode setzt alle df-Zeilen in Spalten eines beliebigen Datenformats um, wobei Schlüssel und Wert verwendet werden

für Eingabe-CSV

id,tag,value
1,US,50a
1,UK,100
1,Can,125
2,US,75
2,UK,150
2,Can,175

Ausgabe

+--+---+---+---+
|id| UK| US|Can|
+--+---+---+---+
| 2|150| 75|175|
| 1|100|50a|125|
+--+---+---+---+

Transponierungsmethode:

def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = {

val distinctCols =   df.select(key).distinct.map { r => r(0) }.collect().toList

val rdd = df.map { row =>
(compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] },
scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any]))
}
val pairRdd = rdd.reduceByKey(_ ++ _)
val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols))
hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols)))

}

private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = {
val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) }
val array = r._1 ++ cols
Row(array: _*)
}

private  def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = {
val idSchema = idCols.map { idCol => srcSchema.apply(idCol) }
val colSchema = srcSchema.apply(distinctCols._1)
val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) }
StructType(idSchema ++ colsSchema)
}

Hauptausschnitt

import Java.util.Date
import org.Apache.spark.SparkConf
import org.Apache.spark.SparkContext
import org.Apache.spark.sql.Row
import org.Apache.spark.sql.DataFrame
import org.Apache.spark.sql.types.StructType
import org.Apache.spark.sql.Hive.HiveContext
import org.Apache.spark.sql.types.StructField


...
...
def main(args: Array[String]): Unit = {

    val sc = new SparkContext(conf)
    val sqlContext = new org.Apache.spark.sql.SQLContext(sc)
    val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true")
    .load("data.csv")
    dfdata1.show()  
    val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value")
    dfOutput.show

}
0
Jaigates

Spark hat das Pivoting des Spark DataFrame verbessert. Der Spark DataFrame API wurde eine Pivot-Funktion zu Spark hinzugefügt = 1.6 Version und es hat ein Leistungsproblem und das wurde korrigiert in Spark 2.0

wenn Sie jedoch eine niedrigere Version verwenden; Beachten Sie, dass Pivot eine sehr teure Operation ist. Daher wird empfohlen, Spaltendaten (falls bekannt) als Argument anzugeben, um wie unten gezeigt zu funktionieren.

val countries = Seq("USA","China","Canada","Mexico")
val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show()

Dies wurde unter Pivoting and Unpivoting Spark DataFrame ausführlich erläutert

Viel Spaß beim Lernen !!

0
Naveen Nelamali

Die eingebaute spark pivot-Funktion ist ineffizient. Die folgende Implementierung funktioniert mit spark 2.4+ - die Idee ist, eine Karte zu aggregieren und die Werte als Spalten zu extrahieren Die einzige Einschränkung besteht darin, dass die Aggregatfunktion in den geschwenkten Spalten nicht verarbeitet wird, sondern nur in den Spalten.

In einer 8M-Tabelle gelten diese Funktionen für Sekunden im Vergleich zu 40 Minuten in der eingebauten spark version:

# pass an optional list of string to avoid computation of columns
def pivot(df, group_by, key, aggFunction, levels=[]):
    if not levels:
        levels = [row[key] for row in df.filter(col(key).isNotNull()).groupBy(col(key)).agg(count(key)).select(key).collect()]
    return df.filter(col(key).isin(*levels) == True).groupBy(group_by).agg(map_from_entries(collect_list(struct(key, expr(aggFunction)))).alias("group_map")).select([group_by] + ["group_map." + l for l in levels])

# Usage
pivot(df, "id", "key", "value")
pivot(df, "id", "key", "array(value)")
// pass an optional list of string to avoid computation of columns
  def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = {
    val levels =
      if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList
      else _levels

    df
      .filter(key.isInCollection(levels))
      .groupBy(groupBy)
      .agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map"))
      .select(groupBy.toString, levels.map(f => "group_map." + f): _*)
  }

// Usage:
pivot(df, col("id"), col("key"), "value")
pivot(df, col("id"), col("key"), "array(value)")
0
parisni