2015-11-23 20 views
28

czy istnieje sposób zastosowania funkcji agregującej do wszystkich (lub listy) ramek danych, gdy robisz grupę przez? Innymi słowy, czy istnieje sposób, aby tego uniknąć w przypadku każdej kolumny:SparkSQL: zastosować funkcje agregujące do listy kolumn

df.groupBy("col1") 
.agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...) 

wielkie dzięki!

Odpowiedz

55

Istnieje wiele sposobów zastosowania funkcji agregujących w wielu kolumnach.

GroupedData klasa zawiera szereg sposobów najczęściej funkcji, w tym count, max, min, mean i sum, które mogą być stosowane bezpośrednio jako następujące:

  • Pyton:

    df = sqlContext.createDataFrame(
        [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)], 
        ("col1", "col2", "col3")) 
    
    df.groupBy("col1").sum() 
    
    ## +----+---------+-----------------+---------+ 
    ## |col1|sum(col1)|  sum(col2)|sum(col3)| 
    ## +----+---------+-----------------+---------+ 
    ## | 1.0|  2.0|    0.8|  1.0| 
    ## |-1.0|  -2.0|6.199999999999999|  0.7| 
    ## +----+---------+-----------------+---------+ 
    
  • Scala

    val df = sc.parallelize(Seq(
        (1.0, 0.3, 1.0), (1.0, 0.5, 0.0), 
        (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)) 
    ).toDF("col1", "col2", "col3") 
    
    df.groupBy($"col1").min().show 
    
    // +----+---------+---------+---------+ 
    // |col1|min(col1)|min(col2)|min(col3)| 
    // +----+---------+---------+---------+ 
    // | 1.0|  1.0|  0.3|  0.0| 
    // |-1.0|  -1.0|  0.6|  0.2| 
    // +----+---------+---------+---------+ 
    

Opcjonalnie można przekazać listę kolumn, które powinny być agregowane

df.groupBy("col1").sum("col2", "col3") 

Można także przekazać słownika/mapę z kolumn A klawisze i funkcje wartości:

  • Python

    exprs = {x: "sum" for x in df.columns} 
    df.groupBy("col1").agg(exprs).show() 
    
    ## +----+---------+ 
    ## |col1|avg(col3)| 
    ## +----+---------+ 
    ## | 1.0|  0.5| 
    ## |-1.0|  0.35| 
    ## +----+---------+ 
    
  • Scala

    val exprs = df.columns.map((_ -> "mean")).toMap 
    df.groupBy($"col1").agg(exprs).show() 
    
    // +----+---------+------------------+---------+ 
    // |col1|avg(col1)|   avg(col2)|avg(col3)| 
    // +----+---------+------------------+---------+ 
    // | 1.0|  1.0|    0.4|  0.5| 
    // |-1.0|  -1.0|3.0999999999999996|  0.35| 
    // +----+---------+------------------+---------+ 
    

Wreszcie można użyć varargs:

  • Python

    from pyspark.sql.functions import min 
    
    exprs = [min(x) for x in df.columns] 
    df.groupBy("col1").agg(*exprs).show() 
    
  • Scala

    import org.apache.spark.sql.functions.sum 
    
    val exprs = df.columns.map(sum(_)) 
    df.groupBy($"col1").agg(exprs.head, exprs.tail: _*) 
    

Istnieje kilka innych sposobów osiągnięcia podobnego efektu, ale powinny one wystarczyć na większość czasu.

+0

wydaje 'aggregateBy' byłoby tu zastosowanie. Jest szybszy (znacznie szybciej) niż 'groupBy'.Och, czekaj - 'DataFrame' nie eksponuje' aggregateBy' - 'agg' wskazuje na' groupBy'. Cóż, to oznacza, że ​​'DataFrames' są * wolne * .. – javadba

+0

@javadba Nie, to oznacza tylko, że' Dataset.groupBy'/'Dataset.groupByKey' i' RDD.groupBy'/'RDD.groupByKey' mają, w ogólnym przypadku, inna semantyka. W przypadku prostych agregacji 'DataFrame' [sprawdź to] (http://stackoverflow.com/a/32903568/1560062). Jest w tym coś więcej, ale nie jest to tutaj ważne. – zero323

+0

Ładne informacje! przegłosował inną odpowiedź – javadba

6

Innym przykładem tego samego pojęcia - ale powiedzieć - masz 2 różnych kolumn - i chcesz zastosować różne funkcje AGG do każdego z nich, tj

f.groupBy("col1").agg(sum("col2").alias("col2"), avg("col3").alias("col3"), ...) 

Oto sposób, aby go osiągnąć - choć jeszcze nie wiem jak dodać alias w tym przypadku

patrz przykład poniżej - za pomocą mapy

val Claim1 = StructType(Seq(StructField("pid", StringType, true),StructField("diag1", StringType, true),StructField("diag2", StringType, true), StructField("allowed", IntegerType, true), StructField("allowed1", IntegerType, true))) 
val claimsData1 = Seq(("PID1", "diag1", "diag2", 100, 200), ("PID1", "diag2", "diag3", 300, 600), ("PID1", "diag1", "diag5", 340, 680), ("PID2", "diag3", "diag4", 245, 490), ("PID2", "diag2", "diag1", 124, 248)) 

val claimRDD1 = sc.parallelize(claimsData1) 
val claimRDDRow1 = claimRDD1.map(p => Row(p._1, p._2, p._3, p._4, p._5)) 
val claimRDD2DF1 = sqlContext.createDataFrame(claimRDDRow1, Claim1) 

val l = List("allowed", "allowed1") 
val exprs = l.map((_ -> "sum")).toMap 
claimRDD2DF1.groupBy("pid").agg(exprs) show false 
val exprs = Map("allowed" -> "sum", "allowed1" -> "avg") 

claimRDD2DF1.groupBy("pid").agg(exprs) show false