一. 需求
1.1 需求简介
这里的热门商品是从点击量的维度来看的.
计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。
1.2 思路分析
使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf
查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区. 与 Product_info 表连接得到产品名称
按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
每个地区内按照点击次数降序排列
只取前三名. 并把结果保存在数据库中
城市备注需要自定义 UDAF 函数
二. 实际操作
1. 准备数据
  我们这次 Spark-sql 操作中所有的数据均来自 Hive.
  首先在 Hive 中创建表, 并导入数据.
  一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表
1. 打开Hive
2. 创建三个表
[mw_shl_code=sql,true]CREATE TABLE `user_visit_action`(
`date` string,
`user_id` bigint,
`session_id` string,
`page_id` bigint,
`action_time` string,
`search_keyword` string,
`click_category_id` bigint,
`click_product_id` bigint,
`order_category_ids` string,
`order_product_ids` string,
`pay_category_ids` string,
`pay_product_ids` string,
`city_id` bigint)
row format delimited fields terminated by '\t';
CREATE TABLE `product_info`(
`product_id` bigint,
`product_name` string,
`extend_info` string)
row format delimited fields terminated by '\t';
CREATE TABLE `city_info`(
`city_id` bigint,
`city_name` string,
`area` string)
row format delimited fields terminated by '\t';
[/mw_shl_code]
3. 上传数据
[mw_shl_code=bash,true]load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;
[/mw_shl_code]
4. 测试是否上传成功
[mw_shl_code=bash,true]hive> select * from city_info;
[/mw_shl_code]
2. 显示各区域热门商品 Top3
[mw_shl_code=sql,true]
// user_visit_action product_info city_info
1. 先把需要的字段查出来 t1
select
ci.*,
pi.product_name,
click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id
2. 按照地区和商品名称聚合
select
area,
product_name,
count(*) count
from t1
group by area , product_name
3. 按照地区进行分组开窗 排序 开窗函数 t3 // (rank(1 2 2 4 5...) row_number(1 2 3 4...) dense_rank(1 2 2 3 4...))
select
area,
product_name,
count,
rank() over(partition by area order by count desc)
from t2
4. 过滤出来名次小于等于3的
select
area,
product_name,
count
from t3
where rk <=3
[/mw_shl_code]
2. 运行结果
3. 定义udaf函数 得到需求结果
[mw_shl_code=java,true]package com.buwenbuhuo.spark.sql.project
import java.text.DecimalFormat
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
**
*
* @author 不温卜火
* *
* @create 2020-08-06 13:24
**
* MyCSDN : https://buwenbuhuo.blog.csdn.net/
*
*/
class CityRemarkUDAF extends UserDefinedAggregateFunction {
// 输入数据的类型: 北京 String
override def inputSchema: StructType = {
StructType(Array(StructField("city", StringType)))
}
// 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000 Map, 总的点击量 1000/?
override def bufferSchema: StructType = {
StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
}
// 输出的数据类型 "北京21.2%,天津13.2%,其他65.6%" String
override def dataType: DataType = StringType
// 相同的输入是否应用有相同的输出.
override def deterministic: Boolean = true
// 给存储数据初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//初始化map缓存
buffer(0) = Map[String, Long]()
// 初始化总的点击量
buffer(1) = 0L
}
// 分区内合并 Map[城市名, 点击量]
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
input match {
case Row(cityName: String) =>
// 1. 总的点击量 + 1
buffer(1) = buffer.getLong(1) + 1L
// 2. 给这个城市的点击量 +1 => 找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
case _ =>
}
}
// 分区间的合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val map1 = buffer1.getAs[Map[String, Long]](0)
val map2 = buffer2.getAs[Map[String, Long]](0)
val total1: Long = buffer1.getLong(1)
val total2: Long = buffer2.getLong(1)
// 1. 总数的聚合
buffer1(1) = total1 + total2
// 2. map的聚合
buffer1(0) = map1.foldLeft(map2) {
case (map, (cityName, count)) =>
map + (cityName -> (map.getOrElse(cityName, 0L) + count))
}
}
// 最终的输出结果
override def evaluate(buffer: Row): Any = {
// "北京21.2%,天津13.2%,其他65.6%"
val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
val total: Long = buffer.getLong(1)
val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
var cityRemarks: List[CityRemark] = cityCountTop2.map {
case (cityName, count) => CityRemark(cityName, count.toDouble / total)
}
// CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
cityRemarks.mkString(",")
}
}
case class CityRemark(cityName: String, cityRatio: Double) {
val formatter = new DecimalFormat("0.00%")
override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}
[/mw_shl_code]
运行结果
4 .保存到Mysql
1. 源码
[mw_shl_code=java,true] val props: Properties = new Properties()
props.put("user","root")
props.put("password","199712")
spark.sql(
"""
|select
| area,
| product_name,
| count,
| remark
|from t3
|where rk<=3
|""".stripMargin)
.coalesce(1)
.write
.mode("overwrite")
.jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
[/mw_shl_code]
2.运行结果
三. 完整代码
1. udaf
[mw_shl_code=java,true]package com.buwenbuhuo.spark.sql.project
import java.text.DecimalFormat
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
**
*
* @author 不温卜火
* *
* @create 2020-08-06 13:24
**
* MyCSDN : https://buwenbuhuo.blog.csdn.net/
*
*/
class CityRemarkUDAF extends UserDefinedAggregateFunction {
// 输入数据的类型: 北京 String
override def inputSchema: StructType = {
StructType(Array(StructField("city", StringType)))
}
// 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000 Map, 总的点击量 1000/?
override def bufferSchema: StructType = {
StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
}
// 输出的数据类型 "北京21.2%,天津13.2%,其他65.6%" String
override def dataType: DataType = StringType
// 相同的输入是否应用有相同的输出.
override def deterministic: Boolean = true
// 给存储数据初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//初始化map缓存
buffer(0) = Map[String, Long]()
// 初始化总的点击量
buffer(1) = 0L
}
// 分区内合并 Map[城市名, 点击量]
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
input match {
case Row(cityName: String) =>
// 1. 总的点击量 + 1
buffer(1) = buffer.getLong(1) + 1L
// 2. 给这个城市的点击量 +1 => 找到缓冲区的map,取出来这个城市原来的点击 + 1 ,再复制过去
val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
case _ =>
}
}
// 分区间的合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val map1 = buffer1.getAs[Map[String, Long]](0)
val map2 = buffer2.getAs[Map[String, Long]](0)
val total1: Long = buffer1.getLong(1)
val total2: Long = buffer2.getLong(1)
// 1. 总数的聚合
buffer1(1) = total1 + total2
// 2. map的聚合
buffer1(0) = map1.foldLeft(map2) {
case (map, (cityName, count)) =>
map + (cityName -> (map.getOrElse(cityName, 0L) + count))
}
}
// 最终的输出结果
override def evaluate(buffer: Row): Any = {
// "北京21.2%,天津13.2%,其他65.6%"
val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
val total: Long = buffer.getLong(1)
val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
var cityRemarks: List[CityRemark] = cityCountTop2.map {
case (cityName, count) => CityRemark(cityName, count.toDouble / total)
}
// CityRemark("其他",1 - cityremarks.foldLeft(0D)(_+_.cityRatio))
cityRemarks :+= CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
cityRemarks.mkString(",")
}
}
case class CityRemark(cityName: String, cityRatio: Double) {
val formatter = new DecimalFormat("0.00%")
override def toString: String = s"$cityName:${formatter.format(cityRatio)}"
}
[/mw_shl_code]
2. 主程序(具体实现)
[mw_shl_code=java,true]package com.buwenbuhuo.spark.sql.project
import java.util.Properties
import org.apache.spark.sql.SparkSession
/**
**
*
* @author 不温卜火
* *
* @create 2020-08-05 19:01
**
* MyCSDN : https://buwenbuhuo.blog.csdn.net/
*
*/
object SqlApp {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.master("local")
.appName("SqlApp")
.enableHiveSupport()
.getOrCreate()
import spark.implicits._
spark.udf.register("remark",new CityRemarkUDAF)
// 去执行sql,从hive查询数据
spark.sql("use spark0806")
spark.sql(
"""
|select
| ci.*,
| pi.product_name,
| uva.click_product_id
|from user_visit_action uva
|join product_info pi on uva.click_product_id=pi.product_id
|join city_info ci on uva.city_id=ci.city_id
|
|""".stripMargin).createOrReplaceTempView("t1")
spark.sql(
"""
|select
| area,
| product_name,
| count(*) count,
| remark(city_name) remark
|from t1
|group by area, product_name
|""".stripMargin).createOrReplaceTempView("t2")
spark.sql(
"""
|select
| area,
| product_name,
| count,
| remark,
| rank() over(partition by area order by count desc) rk
|from t2
|""".stripMargin).createOrReplaceTempView("t3")
val props: Properties = new Properties()
props.put("user","root")
props.put("password","199712")
spark.sql(
"""
|select
| area,
| product_name,
| count,
| remark
|from t3
|where rk<=3
|""".stripMargin)
.coalesce(1)
.write
.mode("overwrite")
.jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
// 把结果写入到mysql中
spark.close()
}
}
[/mw_shl_code]
原文链接:
https://blog.csdn.net/qq_16146103/article/details/107824095
作者:不温卜火
|