
Spark SQL Project: Computing the Top-N Hot Products per Region

I. Requirements

1.1 Requirement Overview
Here, "hot products" are measured by click count.

For each region, compute the top three hot products, and annotate each product with the share of its clicks that came from the major cities; cities beyond the top two are rolled up as "其他" (Other).
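For example (hypothetical numbers): if a product's clicks in the 华北 region were 21.2% from 北京 and 13.2% from 天津, its remark would read 北京:21.2%,天津:13.2%,其他:65.6%.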



1.2 Approach
Implement it in SQL; for the complex parts, fall back to a UDF or UDAF.

Query all the click records, join with the city_info table to get each city's region, and join with the product_info table to get the product name.
Group by region and product, and count each product's total clicks per region.
Within each region, sort by click count in descending order.
Keep only the top three and save the result to the database.
The city remark requires a custom UDAF (a condensed sketch of the full pipeline follows below).
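Putting the steps above together, here is a minimal sketch of the whole pipeline as a single nested query (assuming the three Hive tables introduced below, and a UDAF already registered under the name remark; the step-by-step version follows in section II):

[mw_shl_code=scala,true]// Compact sketch of the plan above -- assumes the tables exist and the
// "remark" UDAF (defined later in this post) has already been registered.
spark.sql(
  """
    |select area, product_name, count, remark
    |from (
    |  select area, product_name, count, remark,
    |         rank() over(partition by area order by count desc) rk
    |  from (
    |    select ci.area, pi.product_name, count(*) count, remark(ci.city_name) remark
    |    from user_visit_action uva
    |    join product_info pi on uva.click_product_id = pi.product_id
    |    join city_info ci on uva.city_id = ci.city_id
    |    group by ci.area, pi.product_name
    |  ) t2
    |) t3
    |where rk <= 3
    |""".stripMargin).show()
[/mw_shl_code]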



II. Hands-On
1. Prepare the Data
  All the data in this Spark SQL job comes from Hive.

  First create the tables in Hive and load the data.

  There are 3 tables in total: 1 user-behavior table, 1 city table, and 1 product table.

1. Open Hive



2. Create the Three Tables
[mw_shl_code=sql,true]CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';

CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';

CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';

[/mw_shl_code]


3. Load the Data

[mw_shl_code=bash,true]load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;
[/mw_shl_code]


4. Verify the Load Succeeded

[mw_shl_code=bash,true]hive> select * from city_info;
[/mw_shl_code]


2. Show the Top-3 Hot Products per Region

[mw_shl_code=sql,true]
-- Tables: user_visit_action, product_info, city_info

-- 1. Select the fields we need  (t1)
select
    ci.*,
    pi.product_name,
    uva.click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id

-- 2. Aggregate by region and product name  (t2)
select
    area,
    product_name,
    count(*)  count
from t1
group by area, product_name

-- 3. Open a window per region and rank  (t3)
-- rank(): 1 2 2 4 5...   row_number(): 1 2 3 4...   dense_rank(): 1 2 2 3 4...
select
    area,
    product_name,
    count,
    rank() over(partition by area order by count desc) rk
from t2

-- 4. Keep only the rows ranked 3 or better
select
    area,
    product_name,
    count
from t3
where rk <= 3
[/mw_shl_code]

Run result


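As the comment in step 3 above notes, Spark SQL ships three ranking window functions that differ only in how they number ties. A minimal sketch on hypothetical data to make the difference concrete:

[mw_shl_code=scala,true]// Hypothetical data: two products tied at 100 clicks, one at 80.
import spark.implicits._
Seq(("华北", "商品A", 100), ("华北", "商品B", 100), ("华北", "商品C", 80))
  .toDF("area", "product_name", "count")
  .createOrReplaceTempView("demo")

spark.sql(
  """
    |select area, product_name, count,
    |       rank()       over(partition by area order by count desc) rk,
    |       row_number() over(partition by area order by count desc) rn,
    |       dense_rank() over(partition by area order by count desc) dr
    |from demo
    |""".stripMargin).show()
// rank: 1,1,3    row_number: 1,2,3    dense_rank: 1,1,2
[/mw_shl_code]

rank() is used here because tied products should share a place; note that a tie at rank 3 can then return more than three rows for a region.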

3. Define the UDAF to Produce the City Remark

[mw_shl_code=scala,true]package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 * @author 不温卜火
 * @create 2020-08-06 13:24
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // Input type: one city name per click record, e.g. "北京" (String)
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // Buffer type: for each (region, product) group, a map of city -> click count
  // (e.g. 北京->1000, 天津->5000) plus the total click count used as the denominator
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // Output type: e.g. "北京:21.20%,天津:13.20%,其他:65.60%" (String)
  override def dataType: DataType = StringType

  // Whether the same input always yields the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // initialize the city -> count map
    buffer(0) = Map[String, Long]()
    // initialize the total click count
    buffer(1) = 0L
  }

  // Within-partition update: Map[cityName, clickCount]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. total clicks + 1
        buffer(1) = buffer.getLong(1) + 1L
        // 2. this city's clicks + 1: read the map from the buffer, bump this
        //    city's count, then assign the updated map back
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. merge the totals
    buffer1(1) = total1 + total2
    // 2. merge the maps, summing the counts of cities present in both
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }
  }

  // Final output, e.g. "北京:21.20%,天津:13.20%,其他:65.60%"
  override def evaluate(buffer: Row): Any = {
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    // take the two cities with the most clicks
    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
    // everything else is rolled up as "其他" (Other)
    // alternative: CityRemark("其他", 1 - cityRemarks.foldLeft(0D)(_ + _.cityRatio))
    cityRemarks :+= CityRemark("其他", cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}

[/mw_shl_code]
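Before wiring it into the full job, a quick toy check of the UDAF's behavior (the sample rows and view name are made up for illustration):

[mw_shl_code=scala,true]// Toy check of CityRemarkUDAF on hypothetical data:
// 3 clicks from 北京 and 1 from 天津, so total = 4 and the expected
// output is "北京:75.00%,天津:25.00%,其他:0.00%".
import spark.implicits._
spark.udf.register("remark", new CityRemarkUDAF)
Seq("北京", "北京", "北京", "天津").toDF("city_name").createOrReplaceTempView("clicks")
spark.sql("select remark(city_name) as remark from clicks").show(false)
[/mw_shl_code]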

Run result



4. Save to MySQL

1. Source code

[mw_shl_code=scala,true]    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
[/mw_shl_code]
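To sanity-check the write, you can read the table back through the same JDBC connection (a sketch reusing the props and URL from above):

[mw_shl_code=scala,true]// Read the result table back from MySQL to verify the write.
spark.read
  .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
  .show(false)
[/mw_shl_code]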

2. Run result




III. Complete Code

1. UDAF

[mw_shl_code=scala,true]package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 * @author 不温卜火
 * @create 2020-08-06 13:24
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // Input type: one city name per click record, e.g. "北京" (String)
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // Buffer type: for each (region, product) group, a map of city -> click count
  // (e.g. 北京->1000, 天津->5000) plus the total click count used as the denominator
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // Output type: e.g. "北京:21.20%,天津:13.20%,其他:65.60%" (String)
  override def dataType: DataType = StringType

  // Whether the same input always yields the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // initialize the city -> count map
    buffer(0) = Map[String, Long]()
    // initialize the total click count
    buffer(1) = 0L
  }

  // Within-partition update: Map[cityName, clickCount]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. total clicks + 1
        buffer(1) = buffer.getLong(1) + 1L
        // 2. this city's clicks + 1: read the map from the buffer, bump this
        //    city's count, then assign the updated map back
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  // Merge buffers across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. merge the totals
    buffer1(1) = total1 + total2
    // 2. merge the maps, summing the counts of cities present in both
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map + (cityName -> (map.getOrElse(cityName, 0L) + count))
    }
  }

  // Final output, e.g. "北京:21.20%,天津:13.20%,其他:65.60%"
  override def evaluate(buffer: Row): Any = {
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    // take the two cities with the most clicks
    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
    // everything else is rolled up as "其他" (Other)
    // alternative: CityRemark("其他", 1 - cityRemarks.foldLeft(0D)(_ + _.cityRatio))
    cityRemarks :+= CityRemark("其他", cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}

[/mw_shl_code]

2. Main Program (Implementation)


[mw_shl_code=scala,true]package com.buwenbuhuo.spark.sql.project

import java.util.Properties

import org.apache.spark.sql.SparkSession

/**
 * @author 不温卜火
 * @create 2020-08-05 19:01
 *         MyCSDN: https://buwenbuhuo.blog.csdn.net/
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    spark.udf.register("remark", new CityRemarkUDAF)

    // run the SQL against Hive
    spark.sql("use spark0806")

    // t1: join the click records with product and city info
    spark.sql(
      """
        |select
        |    ci.*,
        |    pi.product_name,
        |    uva.click_product_id
        |from user_visit_action uva
        |join product_info pi on uva.click_product_id=pi.product_id
        |join city_info ci on uva.city_id=ci.city_id
        |""".stripMargin).createOrReplaceTempView("t1")

    // t2: aggregate by region and product name; the UDAF builds the city remark
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count(*) count,
        |    remark(city_name) remark
        |from t1
        |group by area, product_name
        |""".stripMargin).createOrReplaceTempView("t2")

    // t3: rank products within each region by click count
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark,
        |    rank() over(partition by area order by count desc) rk
        |from t2
        |""".stripMargin).createOrReplaceTempView("t3")

    // keep the top 3 per region and write the result to MySQL
    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)

    spark.close()
  }
}

[/mw_shl_code]
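To build and run this, spark-sql, spark-hive, and the MySQL JDBC driver must be on the classpath. A build.sbt sketch; the versions here are assumptions, match them to your cluster:

[mw_shl_code=scala,true]// build.sbt sketch -- versions are assumptions, adjust to your environment.
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"  % "2.4.5",
  "org.apache.spark" %% "spark-hive" % "2.4.5",
  "mysql"            %  "mysql-connector-java" % "5.1.47"
)
[/mw_shl_code]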


Original article:
https://blog.csdn.net/qq_16146103/article/details/107824095

Author: 不温卜火
