topk on GPU

topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
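Semantics:
Returns the k largest elements of the given input tensor along a given dimension. If dim is not given, the last dimension of the input is used. If largest is False, the k smallest elements are returned instead. If sorted is True, the returned elements are in sorted order. The result is a (values, indices) pair, where indices gives the position of each selected element along dim.

[Figure: topk]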

example:

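>>> import torch
>>> input = torch.tensor([[-3., -1., -2., -8., -7., -4., -9., -6.],
...                       [ 3.,  1.,  2.,  8.,  7.,  4.,  9.,  6.]], dtype=float)
>>> output_data, output_index = input.topk(5, 1, True, True)  # k=5, dim=1, largest, sorted
>>> print(output_data)
tensor([[-1., -2., -3., -4., -6.],
        [ 9.,  8.,  7.,  6.,  4.]], dtype=torch.float64)
>>> print(output_index)
tensor([[1, 2, 0, 5, 7],
        [6, 3, 4, 7, 5]])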
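A plain-Python reference for one row can help when checking a GPU kernel against these semantics. This is a minimal sketch, not PyTorch code. `topk_reference` is a hypothetical helper that assumes `sorted=True` behaviour, and the order of indices among tied values may differ from what PyTorch returns.

```python
def topk_reference(row, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) entries of a 1-D list.

    Hypothetical CPU reference for one row of topk with sorted=True:
    values come back in descending order when largest=True.
    """
    # Rank positions by value; Python's sort is stable, so ties keep input order.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)
    picked = order[:k]
    return [row[i] for i in picked], picked


data = [[-3., -1., -2., -8., -7., -4., -9., -6.],
        [3., 1., 2., 8., 7., 4., 9., 6.]]
for row in data:
    print(topk_reference(row, 5))
```

For the two rows it prints `([-1.0, -2.0, -3.0, -4.0, -6.0], [1, 2, 0, 5, 7])` and `([9.0, 8.0, 7.0, 6.0, 4.0], [6, 3, 4, 7, 5])`.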

Select topk value – Radix select

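Radix select is not a comparison-based selection algorithm but a counting-based one. Keys are examined one digit at a time, starting from the most significant digit. When a digit is n bits wide, 2^n counters are prepared, one for each possible digit value, and each pass builds a histogram of that digit over the keys that are still candidates.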

Simple Example:

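# Get the top 5 of ten 2-bit values (one 2-bit digit, so 2^2 = 4 counters)
(0, 3, 2, 2, 3, 2, 0, 3, 2, 1)

Step0: count[input & 0b11]++ for every element, giving count[0..3] = {2, 1, 4, 3}
Step1: count[0b11] = 3 < k = 5, so all three 3s are in the top 5; remain = 5 - 3 = 2
Step2: count[0b10] = 4 >= remain = 2, so the 5th largest (topk) value is 2
Step3: has_topk = (input > 2) selects the three 3s; the remaining 2 slots are filled with elements equal to 2 (four candidates, only two are kept)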
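The counting steps generalize to full-width keys by repeating them digit by digit, from the most significant digit down, and counting only the keys whose higher digits match the digits already chosen. The sketch below assumes the keys are already unsigned integers (for floats, after the conversion described next). `radix_select_kth_largest` is a hypothetical name. A GPU version would build each histogram in parallel; this sequential Python only simulates the logic.

```python
def radix_select_kth_largest(keys, k, key_bits=32, digit_bits=2):
    """Return the k-th largest (1-based) of a list of unsigned integer keys.

    Sketch of most-significant-digit radix select. Each pass builds a histogram
    of one digit over the keys that still match the digits fixed so far, then
    walks the counters from the largest digit value down.
    Assumes 0 <= key < 2**key_bits and key_bits % digit_bits == 0.
    """
    if not 1 <= k <= len(keys):
        raise ValueError("k must be between 1 and len(keys)")
    radix = 1 << digit_bits            # 2**digit_bits counters per pass
    digit_mask = radix - 1
    desired = 0                        # digits of the answer fixed so far
    desired_mask = 0                   # bit positions already fixed in `desired`
    remain = k                         # rank still to locate among the candidates
    for shift in range(key_bits - digit_bits, -1, -digit_bits):
        count = [0] * radix
        for key in keys:
            if (key & desired_mask) == desired:      # still a candidate
                count[(key >> shift) & digit_mask] += 1
        for digit in range(radix - 1, -1, -1):       # largest digit value first
            if count[digit] >= remain:               # the k-th key has this digit
                desired |= digit << shift
                desired_mask |= digit_mask << shift
                break
            remain -= count[digit]                   # these keys all rank above it
    return desired
```

With `key_bits=2` it reproduces the simple example: `radix_select_kth_largest([0, 3, 2, 2, 3, 2, 0, 3, 2, 1], 5, key_bits=2)` returns 2. A wider `digit_bits` means fewer passes but more counters per pass.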

convert fp32 to uint32

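Radix select assumes all keys are unsigned integers, so float32 values must first be converted to uint32 (and double to uint64) with a mapping that preserves their order. For a non-negative float, the IEEE 754 bit pattern already sorts correctly as an unsigned integer, so only the sign bit is flipped (set), which places it above every negative value. For a negative float, a larger magnitude gives a larger bit pattern but a smaller value, so all bits are flipped: this reverses their order and clears the sign bit.
[Figure: convert]

[Figure: sign]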

Test file (float.c):

#include <stdio.h>
#include <stdint.h>
#include <string.h>

/* Convert floats to order-preserving uint32 keys:
 * non-negative -> flip the sign bit, negative -> flip every bit. */
int main(void) {
    const float values[] = {1.0f, 2.0f, -1.0f, -2.0f};
    uint32_t x, mask, out;

    for (int i = 0; i < 4; i++) {
        float v = values[i];
        memcpy(&x, &v, sizeof x);  /* copy the raw bits (avoids the aliasing UB of *(uint32_t*)&v) */
        mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
        out = x ^ mask;
        printf("%5.1f: x=0x%08x mask=0x%08x out=0x%08x\n",
               v, (unsigned)x, (unsigned)mask, (unsigned)out);
    }
    return 0;
}

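Run: compile with debug information, then inspect x, mask and out in gdb after a conversion. The first dump below was taken after a positive input (mask = 0x80000000), the others after v = -2.0. x/4tb prints the four bytes in memory order; on this little-endian machine the least significant byte comes first, so x = 0xc0000000 shows up as 00000000 00000000 00000000 11000000 and out = 0x3fffffff as 11111111 11111111 11111111 00111111.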

$ g++ -g float.c -o float
(gdb) x/4b &mask
0x7fffffffdf70: 0x00 0x00 0x00 0x80
(gdb) x/4tb &x
0x7fffffffdf6c: 00000000 00000000 00000000 11000000
(gdb) x/4tb &mask
0x7fffffffdf70: 11111111 11111111 11111111 11111111
(gdb) x/4tb &out
0x7fffffffdf74: 11111111 11111111 11111111 00111111
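The same order-preserving mapping can be written in Python for experiments on the CPU. This sketch uses `struct` to reinterpret the IEEE 754 bits (the counterpart of the memcpy in C). The function names are hypothetical, and the inverse mapping is included on the assumption that the selected key eventually has to be turned back into a float.

```python
import struct

# struct format codes: 'f'/'I' for 32-bit, 'd'/'Q' for 64-bit, little-endian.
_FORMATS = {32: ("<f", "<I"), 64: ("<d", "<Q")}


def float_to_ordered_uint(v, bits=32):
    """Map a float to an unsigned int whose unsigned order matches the float order."""
    ffmt, ufmt = _FORMATS[bits]
    sign = 1 << (bits - 1)
    all_ones = (1 << bits) - 1
    x = struct.unpack(ufmt, struct.pack(ffmt, v))[0]   # raw IEEE 754 bits
    mask = all_ones if x & sign else sign              # negative: flip all; else flip sign
    return x ^ mask


def ordered_uint_to_float(u, bits=32):
    """Inverse mapping: recover the float from its order-preserving key."""
    ffmt, ufmt = _FORMATS[bits]
    sign = 1 << (bits - 1)
    all_ones = (1 << bits) - 1
    mask = sign if u & sign else all_ones              # sign bit set => was non-negative
    return struct.unpack(ffmt, struct.pack(ufmt, u ^ mask))[0]
```

For example, `float_to_ordered_uint(-2.0)` gives `0x3fffffff`, matching `out` in the gdb dump, and `float_to_ordered_uint(1.0)` gives `0xbf800000`.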

[Figure: value]

[Figure: select]


exclusive prefix scan
Now we have the k-th largest value (topKValue) from the previous step, but a thread that holds one of the top-k elements does not yet know at which output index to write it.
To find out, we perform an exclusive prefix sum over the hasTopK flags: for every thread that has a result, the scan returns the output index at which it writes that result.

bool hasTopK = (input_value >= topKValue);

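Each thread stores its flag in shared local memory, smem[thread_id] = hasTopK, so that the whole block can scan the flags. If several elements are equal to topKValue, the >= test can flag more than k of them, so elements equal to topKValue may only be written while output slots remain.
[Figure: index]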

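The text does not say which scan algorithm the block runs over smem. One simple choice is the Hillis-Steele scan: log2(n) rounds in which every thread adds the value `offset` positions to its left. The sketch below simulates it sequentially in Python; each loop iteration stands in for one thread, and the snapshot copy stands in for double buffering plus a barrier such as CUDA's `__syncthreads()`. The function name is hypothetical.

```python
def exclusive_scan_hillis_steele(flags):
    """Exclusive prefix sum of 0/1 flags, simulating a Hillis-Steele scan in shared memory."""
    smem = [int(f) for f in flags]            # smem[thread_id] = hasTopK
    offset = 1
    while offset < len(smem):                 # log2(n) rounds
        prev = list(smem)                     # read old values: double buffer + barrier
        for tid in range(offset, len(smem)):  # each tid would be a separate thread
            smem[tid] = prev[tid] + prev[tid - offset]
        offset *= 2
    # smem now holds the inclusive scan; shift right by one for the exclusive scan.
    return [0] + smem[:-1] if smem else []
```

For flags `[0, 1, 1, 0, 1]` the result is `[0, 0, 1, 2, 2]`: the three flagged threads write to slots 0, 1 and 2. Hillis-Steele does O(n log n) additions; a work-efficient (Blelloch) scan is a common alternative when n is large.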
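if (hasTopK) {
    // index = exclusive prefix sum of hasTopK over the preceding threads
    // (an inclusive scan would give index + 1, hence output[index - 1])
    output[index] = input_value;
}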
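If several elements equal topKValue, `input_value >= topKValue` can flag more than k of them. One common way to avoid writing past k is two passes: first write the elements strictly greater than topKValue, then fill the remaining slots with elements equal to it. The Python sketch below simulates one thread per element; `gather_topk` and `_exclusive_scan` are hypothetical names.

```python
def _exclusive_scan(flags):
    """Sequential exclusive prefix sum; returns (slots, total)."""
    slots, total = [], 0
    for f in flags:
        slots.append(total)
        total += int(f)
    return slots, total


def gather_topk(values, kth_value, k):
    """Write out the k largest elements, given the k-th largest value (sketch).

    Each index i plays the role of one thread. Output order follows input order.
    """
    output, indices = [None] * k, [None] * k

    # Pass 1: elements strictly greater than the k-th value (fewer than k of them).
    greater = [v > kth_value for v in values]
    slots, n_greater = _exclusive_scan(greater)
    for i, has_topk in enumerate(greater):
        if has_topk:
            output[slots[i]] = values[i]
            indices[slots[i]] = i

    # Pass 2: elements equal to the k-th value, only while output slots remain.
    equal = [v == kth_value for v in values]
    slots, _ = _exclusive_scan(equal)
    for i, has_topk in enumerate(equal):
        slot = n_greater + slots[i]
        if has_topk and slot < k:
            output[slot] = values[i]
            indices[slot] = i
    return output, indices
```

For the 2-bit example with topKValue = 2 and k = 5 it returns `([3, 3, 3, 2, 2], [1, 4, 7, 2, 3])`. The values come out in input order, not value order.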

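Sort the top k values

The prefix-scan write-out places the selected values in input order, not in value order, so when sorted=True they still have to be sorted.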

Bitonic sorter

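Bitonic sort is a sorting algorithm designed for parallel machines: its sequence of compare-and-swap operations is fixed in advance and does not depend on the data, so all comparisons within one stage can run in parallel.
A sequence is called bitonic if it first increases and then decreases. In other words, an array x[0..n-1] is bitonic if there exists an index i with 0 <= i <= n-1 such that

$$x_0 \le x_1 \le \cdots \le x_i \quad \text{and} \quad x_i \ge x_{i+1} \ge \cdots \ge x_{n-1}$$

More generally, any cyclic rotation of such a sequence also counts as bitonic.
Given a bitonic sequence whose length n is a power of two, comparing each element x_j with x_{j+n/2} (for j < n/2) and swapping the pair when it is out of order splits the sequence into two bitonic halves, where every element of the first half is less than or equal to every element of the second. Applying this step recursively to both halves yields a sorted sequence.

[Figure: sort1]

[Figure: sort2]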

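The recursive compare-and-swap step can be sketched in Python as follows. `is_bitonic` and `bitonic_merge` are hypothetical names. `is_bitonic` checks only the increase-then-decrease form of the definition, while `bitonic_merge` also copes with the rotated bitonic halves that appear during the recursion. The length is assumed to be a power of two.

```python
def is_bitonic(seq):
    """True if seq first (weakly) increases and then (weakly) decreases."""
    i, n = 0, len(seq)
    while i + 1 < n and seq[i] <= seq[i + 1]:
        i += 1
    while i + 1 < n and seq[i] >= seq[i + 1]:
        i += 1
    return i + 1 >= n


def bitonic_merge(seq, ascending=True):
    """Sort a bitonic sequence whose length is a power of two (sketch)."""
    n = len(seq)
    if n <= 1:
        return list(seq)
    half = n // 2
    lo, hi = list(seq[:half]), list(seq[half:])
    # Bitonic split: compare element j with element j + n/2, swap if out of order.
    for j in range(half):
        if (lo[j] > hi[j]) == ascending:
            lo[j], hi[j] = hi[j], lo[j]
    # Both halves are bitonic again (possibly rotated) and, for ascending order,
    # every element of lo <= every element of hi, so each half is merged on its own.
    return bitonic_merge(lo, ascending) + bitonic_merge(hi, ascending)
```

For example, `bitonic_merge([1, 4, 6, 8, 7, 3, 2, 0])` returns `[0, 1, 2, 3, 4, 6, 7, 8]`.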

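Unsorted sequence -> bitonic sequence -> sorted sequence
Neighboring runs of length 2, 4, 8, ... are sorted in alternating directions (ascending, then descending), so each pair of adjacent runs forms a bitonic sequence of twice the length. Once the whole array is bitonic, a final bitonic merge (the recursive compare-and-swap step) sorts it.

[Figure: sort3]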
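The full pipeline (unsorted -> bitonic -> sorted) is often written as an iterative network: two nested loops over run size and compare distance. Each (size, stride) stage is a set of independent compare-and-swaps, which a GPU could run with one thread per element and a barrier between stages. The Python below is a sequential simulation under that assumption. `bitonic_sort` is a hypothetical name, the length must be a power of two, and descending order (what topk returns with largest=True and sorted=True) is obtained here by simply reversing the ascending result.

```python
def bitonic_sort(data, descending=False):
    """Sort a list whose length is a power of two with the iterative bitonic network."""
    a = list(data)
    n = len(a)
    if n & (n - 1):
        raise ValueError("length must be a power of two")
    size = 2
    while size <= n:                      # build bitonic runs of size 2, 4, 8, ...
        stride = size // 2
        while stride > 0:                 # bitonic merge within each run
            for i in range(n):            # one "thread" per index
                partner = i ^ stride
                if partner > i:
                    up = (i & size) == 0  # direction of the run containing i
                    if (a[i] > a[partner]) == up:
                        a[i], a[partner] = a[partner], a[i]
            stride //= 2
        size *= 2
    return a[::-1] if descending else a
```

To sort k values when k is not a power of two, one option is to pad with -inf up to the next power of two, sort in descending order, and drop the padding at the end.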