Commit 8f24d345 by tianxing wang

finish lfucache

parent f081fb39
......@@ -25,7 +25,8 @@ message(STATUS "torch_include_dirs" ${TORCH_INCLUDE_DIRS})
find_package (Python3 COMPONENTS Interpreter Development REQUIRED)
set(cache_lib_name gpucache)
add_library(${cache_lib_name} SHARED ${SOURCE_FILES})
add_library(${cache_lib_name} SHARED ${SOURCE_FILES}
src/lfu_cache.h)
set_target_properties(${cache_lib_name} PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
CUDA_ARCHITECTURES "86"
......
#include <cuda_runtime.h>
#include <iostream>
/*
* Different CUDA architectures provide different atomic functions, so they need to be adapted.
......
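A minimal sketch of one such adaptation (illustrative only; the helper name and the CAS-based emulation are assumptions, not part of this commit): architectures without a native 8-bit atomicAdd can emulate it with a 32-bit atomicCAS loop on the containing word.

__device__ inline uint8_t atomicAddUint8(uint8_t *addr, uint8_t val) {
  // Align down to the 32-bit word containing the byte and find its offset.
  uintptr_t raw = reinterpret_cast<uintptr_t>(addr);
  unsigned int *word = reinterpret_cast<unsigned int *>(raw & ~uintptr_t{3});
  unsigned int shift = (raw & 3u) * 8u;
  unsigned int old = *word, assumed;
  do {
    assumed = old;
    uint8_t cur = static_cast<uint8_t>((assumed >> shift) & 0xFFu);
    unsigned int next = (assumed & ~(0xFFu << shift)) |
                        (static_cast<unsigned int>(
                             static_cast<uint8_t>(cur + val)) << shift);
    old = atomicCAS(word, assumed, next);
  } while (assumed != old); // retry until no other thread raced on the word
  return static_cast<uint8_t>((old >> shift) & 0xFFu);
}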
......@@ -49,7 +49,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("NewLRUCache", &gpucache::lrucache::NewLRUCache, "create a lru cache",py::return_value_policy::reference);
/*-----------------------------------------------------------------------------------------------------------------------------------*/
/* lrucache */
/* fifocache */
/*-----------------------------------------------------------------------------------------------------------------------------------*/
py::class_<gpucache::fifocache::FIFOCacheWrapper> fifo_cache(m, "FIFOCache");
......
......@@ -180,6 +180,7 @@ namespace gpucache {
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
......
#include "common.cuh"
#include "lfu_cache.h"
namespace gpucache {
namespace lfucache{
const uint8_t maxTS = UINT8_MAX;
template<typename KeyType, typename ElemType>
struct BucketView {
__device__ BucketView(KeyType *k, ElemType *v, uint8_t *ts, WarpMutex *m,
uint32_t num_elems_per_value)
: mutex(m), bkeys(k), bvalues(v), btimestamps(ts),
num_elems_per_value(num_elems_per_value) {}
__device__ int Get(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1;
if (exist_mask == 0) { // not found
return -1;
} else {
if (ctx.lane_id != slot_num && ts > 1) { // age the other slots (ts stays >= 1, so occupied slots never look empty)
btimestamps[ctx.lane_id] = ts - 1;
}
return slot_num; // slot holding the key
}
}
__device__ int TryPut(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1;
if (exist_mask != 0) {
if (ctx.lane_id == slot_num) {
ts = maxTS;
}
__syncwarp();
}
if (slot_num == -1) { // key doesn't exist
uint32_t occupiedMask = __ballot_sync(warpFullMask, ts != 0);
if (occupiedMask != warpFullMask) { // bucket not full
// occupied slots always form a prefix of the bucket (slots are only
// emptied by Clear), so the first free slot is the population count
slot_num = __popc(occupiedMask);
if (ctx.lane_id == slot_num) {
bkeys[slot_num] = key;
ts = maxTS;
}
}
__syncwarp();
}
btimestamps[ctx.lane_id] = ts;
return slot_num;
}
__device__ int Evict(const ThreadCtx &ctx, const KeyType key,
KeyType *evict_key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
uint32_t evict_mask = __ballot_sync(warpFullMask, ts == 1);
int slot_num = __ffs(evict_mask) - 1;
if(slot_num == -1){ // no ts equals 1, warp-reduce to find the min ts > 0
// propagate the running minimum together with its lane id; shuffling the
// raw ts loses minima found more than one reduction step away
uint8_t min_ts = (ts != 0) ? ts : maxTS; // treat empty slots as maxTS so they never win
uint32_t min_ts_lane_id = ctx.lane_id;
for(int offset = 16; offset > 0; offset /= 2){
uint8_t other_ts = __shfl_down_sync(warpFullMask, min_ts, offset);
uint32_t other_lane = __shfl_down_sync(warpFullMask, min_ts_lane_id, offset);
if(other_ts < min_ts){
min_ts = other_ts;
min_ts_lane_id = other_lane;
}
}
slot_num = __shfl_sync(warpFullMask, min_ts_lane_id, 0); // lane 0 holds the winner
if(ctx.lane_id == 0){
printf("evict slot_num %d, ts = %u\n", slot_num, min_ts);
}
}
if(ctx.lane_id == slot_num) {
*evict_key = lane_key;
bkeys[slot_num] = key;
btimestamps[slot_num] = maxTS;
}
return slot_num;
}
__device__ void ReadOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *out) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
out[i] = bvalues[slot_num * num_elems_per_value + i];
}
}
__device__ void WriteOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *v) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
bvalues[slot_num * num_elems_per_value + i] = v[i];
}
}
WarpMutex *mutex;
KeyType *bkeys;
ElemType *bvalues;
uint8_t *btimestamps;
uint32_t num_elems_per_value;
};
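// Usage sketch (illustrative, not part of this commit): every lane of a warp
// owns one slot of the bucket, so all 32 lanes must call the BucketView
// methods together, as the kernels below do:
//
//   auto bucket = setBucketView<KeyType, ElemType>(
//       ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
//       num_elem_per_value, bucket_id);
//   bucket.mutex->Lock(ctx, bucket_id);
//   int slot = bucket.Get(ctx, key);   // warp-cooperative probe
//   if (slot != -1) bucket.ReadOneValue(ctx, slot, out);
//   bucket.mutex->UnLock(ctx);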
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
template<typename KeyType, typename ElemType>
class LFUCache: public Cache<KeyType, ElemType> {
friend BucketView<KeyType, ElemType> __device__ __host__ setBucketView<KeyType, ElemType>(
ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
public:
explicit LFUCache(const CacheConfig& cfg);
~LFUCache();
uint32_t KeySize() override;
uint32_t ValueSize() override;
uint64_t Capacity() override;
uint32_t NumElemsPerValue() override;
uint32_t MaxQueryNum() override;
int8_t DeviceId() override;
uint32_t Dim() override;
CacheConfig::CacheEvictStrategy Strategy() override;
void Clear() override;
void Get(cudaStream_t stream, uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) override;
void Put(cudaStream_t stream, uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_evict = nullptr,
KeyType *evict_keys = nullptr) override;
private:
KeyType *keys;
ElemType *values;
uint8_t *timestamps{};
uint32_t nbucket; // one bucket per warp: each bucket holds 32 values
void *bucketMutexes{};
// CacheConfig::CacheEvictStrategy strategy;
uint64_t capacity;
uint32_t keySize;
uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim
int8_t deviceId;
uint32_t dim;
// store missing keys and indices for Evict
KeyType *queryKeyBuffer{};
uint32_t *queryIndiceBuffer{};
uint32_t maxQueryNum;
};
template<typename KeyType, typename ElemType>
LFUCache<KeyType, ElemType>::LFUCache(const CacheConfig &cfg)
: keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum), deviceId(cfg.deviceId), dim(cfg.dim) {
// std::cout << "LFUCache constructor" << std::endl;
numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize;
printf("LFUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, deviceId);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
CUDA_CHECK(cudaMalloc((void **) &queryKeyBuffer, maxQueryNum * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &queryIndiceBuffer,
maxQueryNum * sizeof(uint32_t)));
}
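// Layout example (illustrative): with capacity = 8192, valueSize = 32 and
// ElemType = uint32_t, numElemPerValue = 32 / 4 = 8 and nbucket = 8192 / 32
// = 256; bucket b then owns keys[32*b .. 32*b+31] and
// values[32*b*8 .. 32*(b+1)*8 - 1], matching the offsets in setBucketView.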
template<typename KeyType, typename ElemType>
LFUCache<KeyType, ElemType>::~LFUCache() {
// TODO need to add CudaDeviceGuard
CUDA_CHECK(cudaFree(keys));
CUDA_CHECK(cudaFree(values));
CUDA_CHECK(cudaFree(timestamps));
CUDA_CHECK(cudaFree(bucketMutexes));
CUDA_CHECK(cudaFree(queryKeyBuffer));
CUDA_CHECK(cudaFree(queryIndiceBuffer));
}
template<typename KeyType, typename ElemType>
uint32_t LFUCache<KeyType, ElemType>::KeySize() {
return keySize;
}
template<typename KeyType, typename ElemType>
uint32_t LFUCache<KeyType, ElemType>::ValueSize() {
return valueSize;
}
template<typename KeyType, typename ElemType>
uint64_t LFUCache<KeyType, ElemType>::Capacity() {
return capacity;
}
template<typename KeyType, typename ElemType>
uint32_t LFUCache<KeyType, ElemType>::NumElemsPerValue() {
return numElemPerValue;
}
template<typename KeyType, typename ElemType>
uint32_t LFUCache<KeyType, ElemType>::MaxQueryNum() {
return maxQueryNum;
}
template<typename KeyType, typename ElemType>
int8_t LFUCache<KeyType, ElemType>::DeviceId() {
return deviceId;
}
template<typename KeyType, typename ElemType>
uint32_t LFUCache<KeyType, ElemType>::Dim() {
return dim;
}
template<typename KeyType, typename ElemType>
CacheConfig::CacheEvictStrategy LFUCache<KeyType, ElemType>::Strategy() {
return CacheConfig::LFU;
}
template<typename KeyType, typename ElemType>
void LFUCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
}
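// Note: clearing only the timestamps is sufficient because a slot with
// ts == 0 is treated as empty by Get/TryPut, so stale keys and values are
// never observed.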
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id) {
return BucketView<KeyType, ElemType>(
cache_keys + bucket_id * warpsize,
cache_values + bucket_id * warpsize * num_elem_per_value,
cache_timestamps + bucket_id * warpsize,
reinterpret_cast<WarpMutex *>(cache_mutexes) + bucket_id,
num_elem_per_value);
}
template<typename KeyType, typename ElemType>
__global__ void GetInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) {
ThreadCtx ctx{};
__shared__ KeyType blockQueryKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = queries[idx];
size_t hash = getHash<KeyType>(key);
uint32_t bucket_id = hash % nbucket;
blockQueryKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
// each of the 32 lanes compares its own slot with the key;
// on a hit, the value is written to results
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
int slot_num = bucket.Get(ctx, key);
if (slot_num != -1) {
find_mask[idx] = true;
bucket.ReadOneValue(ctx, slot_num, &results[idx * num_elem_per_value]);
}
bucket.mutex->UnLock(ctx);
}
__syncwarp();
}
}
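// Note: each warp first stages up to 32 (key, bucket_id) pairs in shared
// memory so every lane can hash its own query in parallel; the warp then
// probes the staged buckets one at a time with all 32 lanes cooperating.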
template<typename KeyType, typename ElemType>
__global__ void
PutWithoutEvictInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_missing,
KeyType *missing_keys, uint32_t *missing_indices) {
ThreadCtx ctx{};
__shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = put_keys[idx];
size_t hash = getHash(key);
uint32_t bucket_id = hash % nbucket;
blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
uint32_t n_warp_missing = 0;
KeyType warp_missing_key;
uint32_t warp_missing_index = 0;
// the warp handles the n_query keys cooperatively instead of one thread per key
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockPutKeys[ctx.block_warp_idx][i];
// ElemType* Value = &put_values[idx];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
int slot_num = bucket.TryPut(ctx, key);
if (slot_num != -1) { // insert value
// if(ctx.lane_id == 0){
// printf("warp %u put key %u value %u\n",ctx.block_warp_idx, key,
// put_values[idx*num_elem_per_value]);
// }
bucket.WriteOneValue(ctx, slot_num,
&put_values[idx * num_elem_per_value]);
} else { // bucket full; record the missing key and index for Evict
if (ctx.lane_id ==
n_warp_missing) { // lane n keeps the n-th missing key
// printf("warp %u lane_id %u try put key %u fail\n",
// ctx.global_warp_idx, ctx.lane_id, key);
warp_missing_key = key;
warp_missing_index = idx;
}
n_warp_missing += 1;
}
bucket.mutex->UnLock(ctx);
}
// reduce missing_key & idx
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (ctx.lane_id == 0) {
base_missing_idx = atomicAdd(n_missing, n_warp_missing);
}
base_missing_idx = __shfl_sync(warpFullMask, base_missing_idx, 0);
if (ctx.lane_id < n_warp_missing) {
// printf("warp %u thread %u add missing_key %u index %u in idx
// %u\n",ctx.global_warp_idx, ctx.lane_id,warp_missing_key,
// warp_missing_index,base_missing_idx + ctx.lane_id);
missing_keys[base_missing_idx + ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + ctx.lane_id] = warp_missing_index;
}
}
__syncwarp();
}
}
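// The missing-key reduction above is a warp-aggregated atomic: lane 0
// reserves n_warp_missing output slots with a single atomicAdd and
// broadcasts the base index, so the warp issues one global atomic instead
// of up to 32. A standalone sketch of the pattern (illustrative; the helper
// name is an assumption, not part of this commit):
__device__ inline uint32_t warpReserveSlots(uint32_t *counter,
                                            uint32_t lane_id,
                                            uint32_t n_items) {
  uint32_t base = 0;
  if (lane_id == 0) base = atomicAdd(counter, n_items);
  return __shfl_sync(warpFullMask, base, 0); // broadcast base to all lanes
}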
template<typename KeyType, typename ElemType>
__global__ void
EvictInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes, uint32_t nbucket,
uint32_t num_elem_per_value, ElemType *put_values,
uint32_t n_missing, KeyType *missing_keys,
uint32_t *missing_indices, uint32_t *num_evict,
KeyType *evict_keys) {
ThreadCtx ctx{};
__shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < n_missing;
offset += ctx.num_warps * warpsize) {
uint32_t n_evict = min(warpsize, n_missing - offset);
if (ctx.lane_id < n_evict) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = missing_keys[idx];
size_t hash = getHash(missing_keys[idx]);
uint32_t bucket_id = hash % nbucket;
blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
for (uint32_t i = 0; i < n_evict; i++) {
uint32_t idx = offset + i;
KeyType key = blockPutKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
KeyType evict_key = 0;
int slot_num = bucket.Evict(ctx, key, &evict_key);
uint32_t evict_keys_idx = 0;
if (num_evict && ctx.lane_id == slot_num) {
evict_keys_idx = atomicAdd(num_evict, 1);
evict_keys[evict_keys_idx] = evict_key;
}
bucket.WriteOneValue(ctx, slot_num,
put_values +
missing_indices[idx] * num_elem_per_value);
bucket.mutex->UnLock(ctx);
}
}
}
template<typename KeyType, typename ElemType>
void
LFUCache<KeyType, ElemType>::Get(cudaStream_t stream, uint32_t num_query, KeyType *queries, ElemType *results,
bool *find_mask) {
assert(num_query <= maxQueryNum && "num_query should not exceed maxQueryNum");
if (num_query == 0) {
return;
}
dim3 block(defaultBlockX);
dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
if (std::is_same<KeyType, int32_t>::value) {
//std::cout << "lfu_get_internal_int32" << "\n";
GetInternal<KeyType, ElemType><<<grid, block, uint32SharedMemorySize, stream>>>(
keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
num_query, queries, results, find_mask);
} else if (std::is_same<KeyType, int64_t>::value) {
//std::cout << "lfu_get_internal_int64" << "\n";
GetInternal<KeyType, ElemType><<<grid, block, uint64SharedMemorySize, stream>>>(
keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
num_query, queries, results, find_mask);
} else {
//std::cout << "unsupported keytype" << std::endl;
throw "Unsupported KeyType!";
}
}
template<typename KeyType, typename ElemType>
void LFUCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_evict, KeyType *evict_keys) {
assert(num_query <= maxQueryNum);
if (num_query == 0) {
return;
}
dim3 block(defaultBlockX);
dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
uint32_t *n_missing;
CUDA_CHECK(cudaMallocManaged(&n_missing, sizeof(uint32_t)));
CUDA_CHECK(cudaMemset(n_missing, 0, sizeof(uint32_t)));
if (std::is_same<KeyType, int32_t>::value || std::is_same<KeyType, uint32_t>::value) {
PutWithoutEvictInternal<KeyType, ElemType>
<<<grid, block, uint32SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, num_query, put_keys, put_values,
n_missing, queryKeyBuffer, queryIndiceBuffer);
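// synchronize so the host can safely read *n_missing (managed memory)
// before launching the eviction pass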
CUDA_CHECK(cudaStreamSynchronize(stream));
EvictInternal<KeyType, ElemType>
<<<grid, block, uint32SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, put_values, *n_missing,
queryKeyBuffer,
queryIndiceBuffer, n_evict, evict_keys);
} else if (std::is_same<KeyType, int64_t>::value || std::is_same<KeyType, uint64_t>::value) {
PutWithoutEvictInternal<KeyType, ElemType>
<<<grid, block, uint64SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, num_query, put_keys, put_values,
n_missing, queryKeyBuffer, queryIndiceBuffer);
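// synchronize so the host can safely read *n_missing (managed memory)
// before launching the eviction pass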
CUDA_CHECK(cudaStreamSynchronize(stream));
EvictInternal<KeyType, ElemType>
<<<grid, block, uint64SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, put_values, *n_missing,
queryKeyBuffer,
queryIndiceBuffer, n_evict, evict_keys);
} else {
CUDA_CHECK(cudaFree(n_missing));
throw std::runtime_error("Unsupported KeyType");
}
// n_missing is allocated with cudaMallocManaged on every call; free it to avoid a leak
CUDA_CHECK(cudaFree(n_missing));
}
}
}
// LFUCacheWrapper API
namespace gpucache {
namespace lfucache {
template<typename ElemType>
void NewInt32LFUCache(void **cache, CacheConfig cfg) {
auto c = new LFUCache<int32_t, ElemType>(cfg);
//std::cout << "NewInt32LFUCache LFUCache keysize " << c->keySize << std::endl;
*cache = reinterpret_cast<void *>(c);
}
template<typename ElemType>
void NewInt64LFUCache(void **cache, CacheConfig cfg) {
auto c = new LFUCache<int64_t, ElemType>(cfg);
*cache = reinterpret_cast<void *>(c);
}
LFUCacheWrapper::LFUCacheWrapper(torch::Tensor t, CacheConfig cfg) {
//std::cout << "LFUCacheWrapper constructor" << std::endl;
dtype = t.scalar_type();
cache_cfg = cfg;
key_is_int32 = cfg.capacity < INT32_MAX;
if (key_is_int32) {
kdtype = torch::kInt32;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int32_lfu_cache", [&] {
NewInt32LFUCache<scalar_t>(&lfu_cache, cfg);
});
} else {
kdtype = torch::kInt64;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int64_lfu_cache", [&] {
NewInt64LFUCache<scalar_t>(&lfu_cache, cfg);
});
}
// AT_DISPATCH_ALL_TYPES(dtype,"lfu_test",[&]{
// auto c = reinterpret_cast<LFUCache<int32_t,scalar_t>*>(lfu_cache);
// std::cout << "LFUCacheWrapper constructor LFUCache dim: " << c->Dim() << " keySize: " << c->KeySize() << "\n";
// });
// std::cout << "LFUCacheWrapper constructor dtype: " << dtype << " kdtype: " << kdtype << "\n";
}
LFUCacheWrapper::~LFUCacheWrapper() {
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_lfu_destructor", [&] {
auto cache = reinterpret_cast<LFUCache<int32_t, scalar_t> *>(lfu_cache);
delete cache;
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "int64_lfu_destructor", [&] {
auto cache = reinterpret_cast<LFUCache<int64_t, scalar_t> *>(lfu_cache);
delete cache;
});
}
//std::cout << "~LFUCacheWrapper()" << std::endl;
}
std::pair<torch::Tensor, torch::Tensor> LFUCacheWrapper::Get(uint32_t num_query, const torch::Tensor queries) {
assert(queries.device().index() == cache_cfg.deviceId &&
"query keys and cache should be on the same device");
assert(queries.size(0) == num_query && "num_query should equal queries.size(0)");
assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed maxQueryNum");
assert(queries.dtype().toScalarType() == kdtype && "key type doesn't match");
//std::cout << "Get queries: " << queries << std::endl;
// printf("hello!!!\n");
//std::cout << "Get num_query: " << num_query <<"\n";
//std::cout << "Get queries.sizes() " << queries.sizes() << "\n";
auto keys = queries.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
auto result = torch::empty({num_query, cache_cfg.dim},
torch::dtype(dtype).device(torch::kCUDA, cache_cfg.deviceId));
// zeros, not empty: the kernel only writes find_mask entries it actually finds
auto find_mask = torch::zeros({num_query},
torch::dtype(torch::kBool).device(torch::kCUDA, cache_cfg.deviceId));
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_lfu_get", [&] {
LFUCache<int32_t, scalar_t> *cache = reinterpret_cast<LFUCache<int32_t, scalar_t> *>(lfu_cache);
// cache->Test();
//std::cout << "keysize: " << cache->KeySize() << "\n";
//std::cout << "dim: " << cache->Dim() << "\n";
cache->Get(stream, num_query, reinterpret_cast<int32_t *>(keys),
reinterpret_cast<scalar_t *>(result.data_ptr()),
reinterpret_cast<bool *>(find_mask.data_ptr()));
});
} else {
//std::cout << "int64_lfu_get" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int64_lfu_get", [&] {
LFUCache<int64_t, scalar_t> *cache = reinterpret_cast<LFUCache<int64_t, scalar_t> *>(lfu_cache);
//std::cout << "cast success!" << std::endl;
cache->Get(stream, num_query, reinterpret_cast<int64_t *>(keys),
reinterpret_cast<scalar_t *>(result.data_ptr()),
reinterpret_cast<bool *>(find_mask.data_ptr()));
});
}
return std::make_pair(result, find_mask);
}
void LFUCacheWrapper::Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values) {
assert(keys.device().index() == cache_cfg.deviceId && cache_cfg.deviceId == values.device().index() &&
"put keys, put values, and cache should be on the same device");
assert(num_query == keys.size(0) && keys.size(0) == values.size(0) &&
"num_query and dim 0 of keys and values should be equal");
assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed maxQueryNum");
assert(values.dtype() == dtype);
auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
if (key_is_int32) {
//std::cout << "int32_lfu_put" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int32_lfu_put", [&] {
auto cache = reinterpret_cast<LFUCache<int32_t, scalar_t> *>(lfu_cache);
//std::cout << "put_lfu_dim" << cache->Dim() << std::endl;
cache->Put(stream, num_query, reinterpret_cast<int32_t *>(keys.data_ptr()),
reinterpret_cast<scalar_t *>(values.data_ptr()));
});
} else {
//std::cout << "int64_lfu_put" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int64_lfu_put", [&] {
auto cache = reinterpret_cast<LFUCache<int64_t, scalar_t> *>(lfu_cache);
cache->Put(stream, num_query, reinterpret_cast<int64_t *>(keys.data_ptr()),
reinterpret_cast<scalar_t *>(values.data_ptr()));
});
}
}
CacheConfig::CacheEvictStrategy LFUCacheWrapper::Strategy() {
return CacheConfig::CacheEvictStrategy::LFU;
}
uint64_t LFUCacheWrapper::Capacity() { return cache_cfg.capacity; }
uint32_t LFUCacheWrapper::KeySize() { return cache_cfg.keySize; }
uint32_t LFUCacheWrapper::ValueSize() { return cache_cfg.valueSize; }
uint32_t LFUCacheWrapper::MaxQueryNum() { return cache_cfg.maxQueryNum; }
uint64_t LFUCacheWrapper::DeviceId() { return cache_cfg.deviceId; }
uint32_t LFUCacheWrapper::Dim() { return cache_cfg.dim; }
void LFUCacheWrapper::Clear() {
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_lfu_clear", [&] {
auto cache = reinterpret_cast<LFUCache<int32_t, scalar_t> *>(lfu_cache);
cache->Clear();
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "int64_lfu_clear", [&] {
auto cache = reinterpret_cast<LFUCache<int64_t, scalar_t> *>(lfu_cache);
cache->Clear();
});
}
}
std::unique_ptr<LFUCacheWrapper> NewLFUCache(at::Tensor t, CacheConfig cfg) {
return std::make_unique<LFUCacheWrapper>(t, cfg);
}
} // namespace lfucache
} // namespace gpucache
#pragma once
#include <torch/extension.h>
#include "cache.h"
namespace gpucache {
namespace lfucache {
class LFUCacheWrapper {
public:
LFUCacheWrapper(at::Tensor t, CacheConfig cfg);
~LFUCacheWrapper();
std::pair<torch::Tensor, torch::Tensor> Get(uint32_t num_query, const torch::Tensor queries);
void Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values);
CacheConfig::CacheEvictStrategy Strategy();
uint64_t Capacity();
uint32_t KeySize();
uint32_t ValueSize();
uint32_t MaxQueryNum();
uint64_t DeviceId();
uint32_t Dim();
void Clear();
private:
void *lfu_cache;
c10::ScalarType dtype; // value dtype
c10::ScalarType kdtype; // key dtype
bool key_is_int32; // only support int32 and int64
CacheConfig cache_cfg;
};
std::unique_ptr<LFUCacheWrapper> NewLFUCache(at::Tensor t, CacheConfig cfg);
} // namespace lfucache
} // namespace gpucache
\ No newline at end of file
......@@ -165,7 +165,7 @@ namespace gpucache {
uint32_t keySize;
uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim
int8_t device_id;
int8_t deviceId;
uint32_t dim;
// store missing keys and indices for Evict
......@@ -178,17 +178,18 @@ namespace gpucache {
template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::LRUCache(const CacheConfig &cfg)
: keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum), device_id(cfg.deviceId), dim(cfg.dim) {
maxQueryNum(cfg.maxQueryNum), deviceId(cfg.deviceId), dim(cfg.dim) {
// std::cout << "LRUCache constructor" << std::endl;
numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize;
printf("LRUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, device_id);
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, deviceId);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
......@@ -200,7 +201,7 @@ namespace gpucache {
template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
// CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
......@@ -235,7 +236,7 @@ namespace gpucache {
uint32_t LRUCache<KeyType, ElemType>::MaxQueryNum() { return maxQueryNum; }
template<typename KeyType, typename ElemType>
int8_t LRUCache<KeyType, ElemType>::DeviceId() { return device_id; }
int8_t LRUCache<KeyType, ElemType>::DeviceId() { return deviceId; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::Dim() { return dim; }
......@@ -282,8 +283,8 @@ namespace gpucache {
}
__syncwarp();
// 32 threads compare it own slot with key
// if find parallel write to result
// 32 threads compare it own slot with key
// if find write to result
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
......@@ -434,7 +435,7 @@ namespace gpucache {
}
}
// TODO async CudaMemcpy
// TODO async CudaMemcpy
template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query,
KeyType *put_keys,
......
......@@ -39,18 +39,18 @@ namespace gpucache {
}
}
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// printf("thread %u CUDA_CHECK lock\n",global_thread_idx);
if (global_thread_idx < n_bucket) {
auto mutex =
reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx;
if (mutex->flag != 0u && mutex->flag != 1u) {
printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx,
mutex->flag);
}
}
}
// __global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) {
// uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// // printf("thread %u CUDA_CHECK lock\n",global_thread_idx);
// if (global_thread_idx < n_bucket) {
// auto mutex =
// reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx;
// if (mutex->flag != 0u && mutex->flag != 1u) {
// printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx,
// mutex->flag);
// }
// }
// }
}
......
......@@ -45,7 +45,7 @@ namespace gpucache{
__global__ void initLocks(uint32_t n_bucket, void *bucketMutexes);
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes);
// __global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes);
......
......@@ -18,7 +18,8 @@
"source": [
"# args: strtegy, capacity, keySize(only support int32 and int64), valueSize(dim * sizeof(elem), elem is decided by the passing tensor dtype of NewLRUCache), deviceId, dim\n",
"# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LRU,65536,4,128,4096,0,32)\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.FIFO,65536,4,128,4096,0,32)"
"# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.FIFO,65536,4,128,4096,0,32)\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LFU,65536,4,128,4096,0,32)"
]
},
{
......@@ -29,7 +30,7 @@
{
"data": {
"text/plain": [
"(<CacheEvictStrategy.FIFO: 0>, 65536, 4, 128, 4096, 0, 32)"
"(<CacheEvictStrategy.LFU: 2>, 65536, 4, 128, 4096, 0, 32)"
]
},
"execution_count": 3,
......@@ -132,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
......@@ -151,7 +152,7 @@
" torch.Size([3, 32]))"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
......
......@@ -6,6 +6,8 @@
#include "../src/lru_cache.cu"
#include "../src/fifo_cache.h"
#include "../src/fifo_cache.cu"
#include "../src/lfu_cache.h"
#include "../src/lfu_cache.cu"
#include <cuda_runtime.h>
#include <unordered_set>
#include <vector>
......@@ -226,6 +228,22 @@ namespace gpucache{
}
TEST(GPUCACHE,LFUCACHE){
CacheConfig cfg{};
cfg.strategy = CacheConfig::CacheEvictStrategy::LFU;
cfg.valueSize = 32;
cfg.capacity = 4096 * 2;
cfg.keySize = 4;
cfg.maxQueryNum = 2048;
cfg.deviceId = 0;
cfg.dim = 8;
lfucache::LFUCache<int32_t,uint32_t> cache(cfg);
TestCache(cache,cfg.dim);
}
TEST(GPUCACHE, FIFOCACHEWRAPPER){
CacheConfig cfg{CacheConfig::CacheEvictStrategy::FIFO,8192,4,32,2048,0,8};
auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
......@@ -243,8 +261,6 @@ namespace gpucache{
find_mask = result.second;
std::cout << " values: " << values << "find_mask: " << find_mask << std::endl;
// auto fifoc = reinterpret_cast<fifocache::FIFOCache<int32_t,int32_t>*>(cache->fifo_cache);
// std::cout <<"test keysize" << fifoc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize());
}
......@@ -269,5 +285,25 @@ namespace gpucache{
// std::cout <<"test keysize" << lruc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize());
}
TEST(GPUCACHE, LFUCACHEWRAPPER){
CacheConfig cfg{CacheConfig::CacheEvictStrategy::LFU,8192,4,32,2048,0,8};
auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
auto cache = lfucache::NewLFUCache(t,cfg);
auto keys = torch::arange(0,5,torch::dtype(torch::kInt32)).to(torch::kCUDA, 0);
auto [values, find_mask] = cache->Get(5,keys);
torch::Tensor put_values = torch::reshape(torch::arange(0, 5 * 8,torch::dtype(torch::kInt32)),{5,8}).to(torch::kCUDA, 0);
std::cout << "put_values: " << put_values << std::endl;
cache->Put(5,keys,put_values);
auto result = cache->Get(5,keys);
values = result.first;
find_mask = result.second;
std::cout << " values: " << values << "find_mask: " << find_mask << std::endl;
CUDA_CHECK(cudaDeviceSynchronize());
}
}