Hive 的 row_number 函数类似于 Oracle 的 ROW_NUMBER 函数：在 Hive 执行 MapReduce 作业的 Reduce 阶段为每一行生成行号，一般与 DISTRIBUTE BY + SORT BY（或 ORDER BY）配合使用。
[mw_shl_code=java,true]import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.LongWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@Description(name = "row_number", value = "_FUNC_(a, [...]) - Assumes that incoming data is SORTed and DISTRIBUTEd according to the given columns, and then returns the row number for each row within the partition,")
public class GenericUDFPartitionRowNumber extends GenericUDF {
private Logger logger = LoggerFactory.getLogger(GenericUDFPartitionRowNumber.class);
private LongWritable rowIndex = new LongWritable(0);
private Object[] partitionColumnValues;
private ObjectInspector[] objectInspectors;
private int[] sortDirections; // holds +1 (for compare() > 0), 0 for unknown, -1 (for compare() < 0)
* Takes the output of compare() and scales it to either, +1, 0 or -1.
* @param val
* @return
protected static int collapseToIndicator(int val) {
if (val > 0) {
return 1;
} else if (val == 0) {
return 0;
} else {
return -1;
* Wraps Object.equals, but allows one or both arguments to be null. Note
* that nullSafeEquals(null, null) == true.
* @param o1
* First object
* @param o2
* Second object
* @return
protected static boolean nullSafeEquals(Object o1, Object o2) {
if (o1 == null && o2 == null) {
return true;
} else if (o1 == null || o2 == null) {
return false;
} else {
return (o1.equals(o2));
public Object evaluate(DeferredObject[] arguments) throws HiveException {
assert (arguments.length == partitionColumnValues.length);
for (int i = 0; i < arguments.length; i++) {
if (partitionColumnValues == null) {
partitionColumnValues = ObjectInspectorUtils.copyToStandardObject(arguments.get(), objectInspectors);
} else if (!nullSafeEquals(arguments.get(), partitionColumnValues)) {
// check sort directions. We know the elements aren't equal.
int newDirection = collapseToIndicator(ObjectInspectorUtils.compare(arguments.get(), objectInspectors,partitionColumnValues, objectInspectors));
if (sortDirections == 0) { // We don't already know what the sort direction should be
sortDirections = newDirection;
} else if (sortDirections != newDirection) {
throw new HiveException( "Data in column: " + i
+ " does not appear to be consistently sorted, so partitionedRowNumber cannot be used.");
// reset everything (well, the remaining column values, because the previous ones haven't changed.
for (int j = i; j < arguments.length; j++) {
partitionColumnValues[j] = ObjectInspectorUtils.copyToStandardObject(arguments[j].get(),objectInspectors[j]);
return rowIndex;
// partition columns are identical. Increment and continue.
rowIndex.set(rowIndex.get() + 1);
return rowIndex;
public String getDisplayString(String[] children) {
return "partitionedRowNumber(" + StringUtils.join(children, ", ") + ")";
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
logger.info("run times");
if (arguments.length == 0) {
throw new UDFArgumentLengthException("The function partitionedRowNumber expects at least 1 argument.");
partitionColumnValues = new Object[arguments.length];
for (ObjectInspector oi : arguments) {
if (ObjectInspectorUtils.isConstantObjectInspector(oi)) {
throw new UDFArgumentException("No constant arguments should be passed to partitionedRowNumber.");
objectInspectors = arguments;
sortDirections = new int[arguments.length];
return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
[mw_shl_code=java,true]import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.exec.WindowFunctionDescription;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
description = @Description(
name = "row_number",
value = "_FUNC_() - The ROW_NUMBER function assigns a unique number (sequentially, starting from 1, as defined by ORDER BY) to each row within the partition."
supportsWindow = false,
pivotResult = true
public class GenericUDAFRowNumber extends AbstractGenericUDAFResolver
static final Log LOG = LogFactory.getLog(GenericUDAFRowNumber.class.getName());
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException
if (parameters.length != 0)
throw new UDFArgumentTypeException(parameters.length - 1,
"No argument is expected.");
return new GenericUDAFRowNumberEvaluator();
static class RowNumberBuffer implements AggregationBuffer
ArrayList<IntWritable> rowNums;
int nextRow;
void init()
rowNums = new ArrayList<IntWritable>();
nextRow = 1;
void incr()
rowNums.add(new IntWritable(nextRow++));
public static class GenericUDAFRowNumberEvaluator extends
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException
super.init(m, parameters);
if (m != Mode.COMPLETE)
throw new HiveException("Only COMPLETE mode supported for row_number function");
return ObjectInspectorFactory.getStandardListObjectInspector(
public AggregationBuffer getNewAggregationBuffer() throws HiveException
return new RowNumberBuffer();
public void reset(AggregationBuffer agg) throws HiveException
((RowNumberBuffer) agg).init();
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException
((RowNumberBuffer) agg).incr();
public Object terminatePartial(AggregationBuffer agg)
throws HiveException
throw new HiveException("terminatePartial not supported");
public void merge(AggregationBuffer agg, Object partial)
throws HiveException
throw new HiveException("merge not supported");
public Object terminate(AggregationBuffer agg) throws HiveException
return ((RowNumberBuffer) agg).rowNums;
[mw_shl_code=sql,true]-- Example: row_number() with an empty OVER () clause numbers the returned rows as one single partition, alongside a per-i windowed sum of f.
select s, sum(f) over (partition by i), row_number() over () from over10k where s = 'tom allen' or s = 'bob steinbeck';[/mw_shl_code]