Commit f081fb39 by tianxing wang

finish fifo cache

parent 52df1e52
......@@ -10,7 +10,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS on)
# gpucache
file(GLOB SOURCE_FILES
# ${CMAKE_CURRENT_SOURCE_DIR}/src/cuda/*
# ${CMAKE_CURRENT_SOURCE_DIR}/src/hash/*.cuh
${CMAKE_CURRENT_SOURCE_DIR}/src/hash/*.cu
${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/*.cu
)
......@@ -35,8 +35,9 @@ target_compile_options(${cache_lib_name} PRIVATE ${TORCH_CXX_FLAGS} -O3)
target_link_libraries(${cache_lib_name} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${cache_lib_name} PRIVATE -DTORCH_EXTENSION_NAME=lib${cache_lib_name})
find_library(TORCH_PYTHON_LIBRARY torch_python "${TORCH_INSTALL_PREFIX}/lib" REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${cache_lib_name} PRIVATE ${TORCH_PYTHON_LIBRARY})
message(STATUS "TORCH_PYTHON_LIBRARY: " ${TORCH_PYTHON_LIBRARY} )
......
No preview for this file type
......@@ -2,7 +2,7 @@ TODO
[] LRU
[] FIFO
[x] FIFO
[] LFU
......
#pragma once
#include <cstdint>
namespace gpucache {
struct CacheConfig {
......@@ -48,9 +48,6 @@ namespace gpucache {
virtual void Put(cudaStream_t stream, uint32_t num_keys, KeyType *keys, ElemType *values,uint32_t *n_evict, KeyType* evict_keys) = 0;
// virtual void* Mutex() = 0;
virtual void Clear() = 0;
virtual uint32_t MaxQueryNum() = 0;
......
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
#include <assert.h>
#include <stdio.h>
#include <cuda_runtime.h>
#include <vector>
#include <memory>
#include "utils.cuh"
#include "cache.h"
#include "hash/hash_function.cuh"
#define CUDA_CHECK(call) \
{ \
const cudaError_t error = call; \
if (error != cudaSuccess) \
......@@ -20,3 +22,5 @@
cudaGetErrorString(error)); \
} \
}
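For context, CUDA_CHECK is the macro every later allocation in this commit is wrapped in. A minimal, hypothetical usage sketch (not part of the patch), assuming common.cuh above is where the macro lives:

#include <cstddef>
#include <cuda_runtime.h>
#include "common.cuh" // assumed location of CUDA_CHECK

// Wrap each CUDA runtime call so a failure surfaces through cudaGetErrorString
// instead of being silently ignored.
void allocScratch(float **buf, std::size_t n) {
  CUDA_CHECK(cudaMalloc((void **) buf, n * sizeof(float)));
  CUDA_CHECK(cudaMemset(*buf, 0, n * sizeof(float)));
}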
#include "common.cuh"
#include "lru_cache.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
py::class_<gpucache::CacheConfig> cfg(m,"CacheConfig");
cfg.def(py::init<>())
.def(py::init<gpucache::CacheConfig::CacheEvictStrategy, uint64_t, uint32_t, uint32_t, uint32_t, int8_t,uint32_t>())
.def_readwrite("strategy", &gpucache::CacheConfig::strategy)
.def_readwrite("capacity",&gpucache::CacheConfig::capacity)
.def_readwrite("key_size", &gpucache::CacheConfig::keySize)
.def_readwrite("value_size",&gpucache::CacheConfig::valueSize)
.def_readwrite("max_query_num", &gpucache::CacheConfig::maxQueryNum)
.def_readwrite("device_id",&gpucache::CacheConfig::deviceId)
.def_readwrite("dim",&gpucache::CacheConfig::dim);
py::enum_<gpucache::CacheConfig::CacheEvictStrategy>(cfg,"CacheEvictStrategy")
.value("LRU",gpucache::CacheConfig::CacheEvictStrategy::LRU)
.value("LFU",gpucache::CacheConfig::CacheEvictStrategy::LFU)
.value("FIFO",gpucache::CacheConfig::CacheEvictStrategy::FIFO)
.export_values();
py::class_<gpucache::LRUCacheWrapper> lru_cache(m, "LRUCache");
lru_cache
.def(py::init<at::Tensor, gpucache::CacheConfig>())
.def("Get",&gpucache::LRUCacheWrapper::Get,"get values for keys, find_mask return whether each key exists in cache")
.def("Put",&gpucache::LRUCacheWrapper::Put,"put key-value pairs")
.def("Strategy",&gpucache::LRUCacheWrapper::Strategy,"get evict strategy")
.def("Capacity",&gpucache::LRUCacheWrapper::Capacity,"return cache capacity")
.def("KeySize",&gpucache::LRUCacheWrapper::KeySize,"return key size")
.def("ValueSize",&gpucache::LRUCacheWrapper::ValueSize,"return value size")
.def("MaxQueryNum",&gpucache::LRUCacheWrapper::MaxQueryNum,"return max number of keys to get or key-values to put once")
.def("Clear",&gpucache::LRUCacheWrapper::Clear,"clear cache")
.def("Device",&gpucache::LRUCacheWrapper::DeviceId,"return device id")
.def("Dim",&gpucache::LRUCacheWrapper::Dim,"return value dim");
m.def("NewLRUCache", &gpucache::NewLRUCache, "create a lru cache",py::return_value_policy::reference);
}
#include "lru_cache.h"
#include "fifo_cache.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
/* CacheConfig */
/*-----------------------------------------------------------------------------------------------------------------------------------*/
py::class_<gpucache::CacheConfig> cfg(m,"CacheConfig");
cfg.def(py::init<>())
.def(py::init<gpucache::CacheConfig::CacheEvictStrategy, uint64_t, uint32_t, uint32_t, uint32_t, int8_t,uint32_t>())
.def_readwrite("strategy", &gpucache::CacheConfig::strategy)
.def_readwrite("capacity",&gpucache::CacheConfig::capacity)
.def_readwrite("keySize", &gpucache::CacheConfig::keySize)
.def_readwrite("valueSize",&gpucache::CacheConfig::valueSize)
.def_readwrite("maxQueryNum", &gpucache::CacheConfig::maxQueryNum)
.def_readwrite("deviceId",&gpucache::CacheConfig::deviceId)
.def_readwrite("dim",&gpucache::CacheConfig::dim);
py::enum_<gpucache::CacheConfig::CacheEvictStrategy>(cfg,"CacheEvictStrategy")
.value("LRU",gpucache::CacheConfig::CacheEvictStrategy::LRU)
.value("LFU",gpucache::CacheConfig::CacheEvictStrategy::LFU)
.value("FIFO",gpucache::CacheConfig::CacheEvictStrategy::FIFO)
.export_values();
/*-----------------------------------------------------------------------------------------------------------------------------------*/
/* lrucache */
/*-----------------------------------------------------------------------------------------------------------------------------------*/
py::class_<gpucache::lrucache::LRUCacheWrapper> lru_cache(m, "LRUCache");
lru_cache
.def(py::init<at::Tensor, gpucache::CacheConfig>())
.def("Get",&gpucache::lrucache::LRUCacheWrapper::Get,"get values for keys, find_mask return whether each key exists in cache")
.def("Put",&gpucache::lrucache::LRUCacheWrapper::Put,"put key-value pairs")
.def("Strategy",&gpucache::lrucache::LRUCacheWrapper::Strategy,"get evict strategy")
.def("Capacity",&gpucache::lrucache::LRUCacheWrapper::Capacity,"return cache capacity")
.def("KeySize",&gpucache::lrucache::LRUCacheWrapper::KeySize,"return key size")
.def("ValueSize",&gpucache::lrucache::LRUCacheWrapper::ValueSize,"return value size")
.def("MaxQueryNum",&gpucache::lrucache::LRUCacheWrapper::MaxQueryNum,"return max number of keys to get or key-values to put once")
.def("Clear",&gpucache::lrucache::LRUCacheWrapper::Clear,"clear cache")
.def("Device",&gpucache::lrucache::LRUCacheWrapper::DeviceId,"return device id")
.def("Dim",&gpucache::lrucache::LRUCacheWrapper::Dim,"return value dim");
m.def("NewLRUCache", &gpucache::lrucache::NewLRUCache, "create a lru cache",py::return_value_policy::reference);
/*-----------------------------------------------------------------------------------------------------------------------------------*/
/* fifocache */
/*-----------------------------------------------------------------------------------------------------------------------------------*/
py::class_<gpucache::fifocache::FIFOCacheWrapper> fifo_cache(m, "FIFOCache");
fifo_cache
.def(py::init<at::Tensor, gpucache::CacheConfig>())
.def("Get",&gpucache::fifocache::FIFOCacheWrapper::Get,"get values for keys, find_mask return whether each key exists in cache")
.def("Put",&gpucache::fifocache::FIFOCacheWrapper::Put,"put key-value pairs")
.def("Strategy",&gpucache::fifocache::FIFOCacheWrapper::Strategy,"get evict strategy")
.def("Capacity",&gpucache::fifocache::FIFOCacheWrapper::Capacity,"return cache capacity")
.def("KeySize",&gpucache::fifocache::FIFOCacheWrapper::KeySize,"return key size")
.def("ValueSize",&gpucache::fifocache::FIFOCacheWrapper::ValueSize,"return value size")
.def("MaxQueryNum",&gpucache::fifocache::FIFOCacheWrapper::MaxQueryNum,"return max number of keys to get or key-values to put once")
.def("Clear",&gpucache::fifocache::FIFOCacheWrapper::Clear,"clear cache")
.def("Device",&gpucache::fifocache::FIFOCacheWrapper::DeviceId,"return device id")
.def("Dim",&gpucache::fifocache::FIFOCacheWrapper::Dim,"return value dim");
m.def("NewFIFOCache", &gpucache::fifocache::NewFIFOCache, "create a fifo cache",py::return_value_policy::reference);
/*-----------------------------------------------------------------------------------------------------------------------------------*/
}
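For orientation, here is a hypothetical host-side sketch of driving the same wrapper API this module binds (the names come from fifo_cache.h in this commit; the config values, tensor shapes and dtypes are illustrative assumptions, not part of the patch):

#include <torch/torch.h>
#include "fifo_cache.h"

int main() {
  gpucache::CacheConfig cfg{};
  cfg.strategy = gpucache::CacheConfig::FIFO;
  cfg.capacity = 1 << 20;           // < INT32_MAX, so the wrapper dispatches to int32 keys
  cfg.keySize = sizeof(int32_t);
  cfg.valueSize = 128 * sizeof(float);
  cfg.maxQueryNum = 4096;
  cfg.deviceId = 0;
  cfg.dim = 128;

  // The tensor only tells the wrapper which value dtype to dispatch on.
  auto ref = torch::empty({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0));
  auto cache = gpucache::fifocache::NewFIFOCache(ref, cfg);

  auto keys = torch::arange(0, 1024, torch::dtype(torch::kInt32).device(torch::kCUDA, 0));
  auto values = torch::randn({1024, 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0));
  cache->Put(1024, keys, values);

  auto out = cache->Get(1024, keys);  // pair of (values, find_mask)
  auto hit_values = out.first;
  auto find_mask = out.second;
  return 0;
}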
#include "common.cuh"
#include "fifo_cache.h"
namespace gpucache {
namespace fifocache {
template<typename KeyType, typename ElemType>
struct BucketView {
__device__ BucketView(KeyType *k, ElemType *v, uint8_t *ts, WarpMutex *m, uint32_t num_elems_per_value)
: mutex(m), bkeys(k), bvalues(v), btimestamps(ts), num_elems_per_value(num_elems_per_value) {}
__device__ int Get(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1; // if not exist slot_num is -1
return slot_num;
}
__device__ int TryPut(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1;
if (slot_num == -1) { // key doesn't exist
// if(ctx.lane_id == 0){
// printf("key %u doesn't exist\n",key);
// }
uint32_t empty_mask = __ballot_sync(warpFullMask, ts != 0);
// if(ctx.lane_id == 0){
// printf("try put key %u, empty_mask %u\n",key,empty_mask);
// }
if (empty_mask != warpFullMask) { // bucket not full
slot_num = __popc(empty_mask);
if (ts > 0) {
ts++;
} else if (ctx.lane_id == slot_num) {
bkeys[slot_num] = key;
ts = 1;
}
}
}
btimestamps[ctx.lane_id] = ts;
//printf("thread %u set btimestamps[%u] = %u\n",ctx.lane_id,ctx.lane_id,ts);
return slot_num;
}
__device__ int Evict(const ThreadCtx &ctx, const KeyType key, KeyType *evict_key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
uint32_t evict_mask = __ballot_sync(warpFullMask, ts == 32);
int slot_num = __ffs(evict_mask) - 1;
if (slot_num == ctx.lane_id) {
ts = 1;
*evict_key = lane_key;
bkeys[ctx.lane_id] = key;
} else {
ts++;
}
btimestamps[ctx.lane_id] = ts;
return slot_num;
}
__device__ void ReadOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *out) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
out[i] = bvalues[slot_num * num_elems_per_value + i];
}
}
__device__ void WriteOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *v) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
bvalues[slot_num * num_elems_per_value + i] = v[i];
}
}
WarpMutex *mutex;
KeyType *bkeys;
ElemType *bvalues;
uint8_t *btimestamps;
uint32_t num_elems_per_value;
};
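// Aside (not part of the commit): a standalone sketch of the warp-wide probe idiom
// used by BucketView::Get/TryPut above. Each of the 32 lanes owns one slot of a
// bucket; __ballot_sync collects the per-lane comparisons, __ffs picks the first
// hit, and __popc over the occupancy mask yields the first free slot, assuming
// occupied slots are packed at the low lanes the way the cache fills them.
__global__ void probeIdiomDemo(const int *slot_keys, const uint8_t *slot_ts, int query) {
  unsigned lane = threadIdx.x % 32;
  int my_key = slot_keys[lane];
  uint8_t my_ts = slot_ts[lane];
  unsigned hit_mask = __ballot_sync(0xFFFFFFFF, my_ts > 0 && my_key == query);
  int hit_slot = __ffs(hit_mask) - 1;                            // -1 when the key is absent
  unsigned occupied_mask = __ballot_sync(0xFFFFFFFF, my_ts > 0);
  int first_free = (occupied_mask == 0xFFFFFFFF) ? -1 : __popc(occupied_mask);
  if (lane == 0) {
    printf("query %d -> hit slot %d, first free slot %d\n", query, hit_slot, first_free);
  }
}
// Example launch (one warp probing one 32-slot bucket): probeIdiomDemo<<<1, 32>>>(d_keys, d_ts, 42);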
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id) {
return BucketView<KeyType, ElemType>(
cache_keys + bucket_id * warpsize,
cache_values + bucket_id * warpsize * num_elem_per_value,
cache_timestamps + bucket_id * warpsize,
reinterpret_cast<WarpMutex *>(cache_mutexes) + bucket_id,
num_elem_per_value);
}
template<typename KeyType, typename ElemType>
class FIFOCache : public Cache<KeyType, ElemType> {
public:
explicit FIFOCache(CacheConfig &cfg);
~FIFOCache();
uint32_t KeySize() override {
return keySize;
}
uint32_t ValueSize() override {
return valueSize;
}
CacheConfig::CacheEvictStrategy Strategy() override {
return CacheConfig::FIFO;
}
uint64_t Capacity() override {
return capacity;
}
uint32_t NumElemsPerValue() override {
return numElemPerValue;
}
int8_t DeviceId() override {
return deviceId;
}
uint32_t Dim() override {
return dim;
}
uint32_t MaxQueryNum() override {
return maxQueryNum;
}
void Get(cudaStream_t stream, uint32_t num_queries, KeyType *keys, ElemType *values, bool *find_mask) override;
void Put(cudaStream_t stream, uint32_t num_queries, KeyType *keys, ElemType *values, uint32_t *n_evict= nullptr,
KeyType *evict_keys=nullptr) override;
void Clear() override;
private:
KeyType *keys;
ElemType *values;
uint8_t *timestamps;
uint32_t nbucket; // 32 values for one bucket
void *bucketMutexes{};
// CacheConfig::CacheEvictStrategy strategy;
uint64_t capacity;
uint32_t keySize;
uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim
int8_t deviceId;
uint32_t dim;
// store missing keys and indices for Evict
KeyType *queryKeyBuffer{};
uint32_t *queryIndiceBuffer{};
uint32_t maxQueryNum;
};
template<typename KeyType, typename ElemType>
FIFOCache<KeyType, ElemType>::FIFOCache(CacheConfig &cfg):keySize(cfg.keySize), valueSize(cfg.valueSize),
capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum),
deviceId(cfg.deviceId), dim(cfg.dim) {
numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize;
printf("FIFOCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, deviceId);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
CUDA_CHECK(cudaMalloc((void **) &queryKeyBuffer, maxQueryNum * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &queryIndiceBuffer,
maxQueryNum * sizeof(uint32_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
}
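// Editorial aside (not part of the commit): worked sizing for the allocations above,
// assuming int32 keys, float values, dim = 128 and capacity = 2^20 cached values.
// The value store dominates, and one WarpMutex serializes each 32-slot bucket.
static_assert((1u << 20) * sizeof(int32_t) == (4u << 20), "keys: 4 MiB");
static_assert(uint64_t(1u << 20) * 128 * sizeof(float) == (512ull << 20), "values: 512 MiB");
static_assert(((1u << 20) + 31) / 32 == 32768, "nbucket = ceil(capacity / 32) = 32,768 buckets");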
template<typename KeyType, typename ElemType>
FIFOCache<KeyType, ElemType>::~FIFOCache() {
// TODO need to add CudaDeviceGuard
CUDA_CHECK(cudaFree(keys));
CUDA_CHECK(cudaFree(values));
CUDA_CHECK(cudaFree(timestamps));
CUDA_CHECK(cudaFree(bucketMutexes));
CUDA_CHECK(cudaFree(queryKeyBuffer));
CUDA_CHECK(cudaFree(queryIndiceBuffer));
}
template<typename KeyType, typename ElemType>
__global__ void GetInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) {
ThreadCtx ctx{};
__shared__ KeyType blockQueryKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = queries[idx];
size_t hash = getHash<KeyType>(key);
uint32_t bucket_id = hash % nbucket;
blockQueryKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
//printf("Get threadId %u map key %u to bucket %u\n",ctx.lane_id,key,bucket_id);
}
__syncwarp();
// 32 threads compare it own slot with key
// if find parallel write to result
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
int slot_num = bucket.Get(ctx, key);
if (slot_num != -1) {
// if(ctx.lane_id == 0){
// printf("read key %u from bucket %u slot_num %u\n",key,bucket_id,slot_num);
// }
find_mask[idx] = true;
bucket.ReadOneValue(ctx, slot_num, &results[idx * num_elem_per_value]);
}
bucket.mutex->UnLock(ctx);
}
__syncwarp();
}
}
template<typename KeyType, typename ElemType>
void
FIFOCache<KeyType, ElemType>::Get(cudaStream_t stream, uint32_t num_query, KeyType *queries, ElemType *results,
bool *find_mask) {
assert(num_query <= maxQueryNum && "num_query should not exceed maxQueryNum");
if (num_query == 0) {
return;
}
dim3 block(defaultBlockX);
dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
if (std::is_same<KeyType, int32_t>::value) {
//std::cout << "fifo_get_internal_int32" << "\n";
GetInternal<KeyType, ElemType><<<grid, block, uint32SharedMemorySize, stream>>>(
keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
num_query, queries, results, find_mask);
} else if (std::is_same<KeyType, int64_t>::value) {
//std::cout << "fifo_get_internal_int64" << "\n";
GetInternal<KeyType, ElemType><<<grid, block, uint64SharedMemorySize, stream>>>(
keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
num_query, queries, results, find_mask);
} else {
//std::cout << "unsupported keytype" << std::endl;
throw "Unsupported KeyType!";
}
}
template<typename KeyType, typename ElemType>
__global__ void
PutWithoutEvictInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_missing,
KeyType *missing_keys, uint32_t *missing_indices) {
ThreadCtx ctx{};
__shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = put_keys[idx];
size_t hash = getHash(key);
uint32_t bucket_id = hash % nbucket;
blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
// printf("Put threadId %u map key %u to bucket %u\n",ctx.lane_id,key,bucket_id);
}
__syncwarp();
uint32_t n_warp_missing = 0;
KeyType warp_missing_key;
uint32_t warp_missing_index = 0;
// 32 threads handle n_query keys together instead of 1 thread for 1 key
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockPutKeys[ctx.block_warp_idx][i];
// ElemType* Value = &put_values[idx];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
int slot_num = bucket.TryPut(ctx, key);
if (slot_num != -1) { // insert value
// if(ctx.lane_id == 0){
// printf("warp %u put key %u value %u in slot %d\n",ctx.block_warp_idx, key,
// put_values[idx*num_elem_per_value], slot_num);
// }
bucket.WriteOneValue(ctx, slot_num,
&put_values[idx * num_elem_per_value]);
} else { // bucket full record missing_key + idx for Evict
if (ctx.lane_id ==
n_warp_missing) { // i-th thread keep i-th missing key
// printf("warp %u lane_id %u try put key %u fail\n",
// ctx.global_warp_idx, ctx.lane_id, key);
warp_missing_key = key;
warp_missing_index = idx;
}
n_warp_missing += 1;
}
bucket.mutex->UnLock(ctx);
}
// reduce missing_key & idx
if (n_warp_missing > 0) {
uint32_t base_missing_idx = 0;
if (ctx.lane_id == 0) {
base_missing_idx = atomicAdd(n_missing, n_warp_missing);
}
base_missing_idx = __shfl_sync(warpFullMask, base_missing_idx, 0);
if (ctx.lane_id < n_warp_missing) {
// printf("warp %u thread %u add missing_key %u index %u in idx
// %u\n",ctx.global_warp_idx, ctx.lane_id,warp_missing_key,
// warp_missing_index,base_missing_idx + ctx.lane_id);
missing_keys[base_missing_idx + ctx.lane_id] = warp_missing_key;
missing_indices[base_missing_idx + ctx.lane_id] = warp_missing_index;
}
}
__syncwarp();
}
}
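// Aside (not part of the commit): a self-contained sketch of the warp-level
// compaction used above to collect missing keys. Each warp counts its misses,
// lane 0 reserves a contiguous range in the global output with one atomicAdd,
// the base offset is broadcast to the warp with __shfl_sync, and each missing
// lane writes at base + its rank among the missing lanes. The kernel above does
// the same thing while serializing over the 32 staged queries, keeping the i-th
// missing key in lane i.
__global__ void warpCompactDemo(const int *keys, const bool *missing, int n,
                                uint32_t *n_out, int *out_keys) {
  unsigned lane = threadIdx.x % 32;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  bool is_missing = idx < n && missing[idx];
  unsigned miss_mask = __ballot_sync(0xFFFFFFFF, is_missing);
  unsigned n_warp_missing = __popc(miss_mask);
  unsigned rank = __popc(miss_mask & ((1u << lane) - 1u)); // missing lanes below this one
  uint32_t base = 0;
  if (lane == 0 && n_warp_missing > 0) {
    base = atomicAdd(n_out, n_warp_missing);  // reserve output slots once per warp
  }
  base = __shfl_sync(0xFFFFFFFF, base, 0);    // broadcast the base offset
  if (is_missing) {
    out_keys[base + rank] = keys[idx];
  }
}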
template<typename KeyType, typename ElemType>
__global__ void
EvictInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes, uint32_t nbucket,
uint32_t num_elem_per_value, ElemType *put_values,
uint32_t n_missing, KeyType *missing_keys,
uint32_t *missing_indices, uint32_t *num_evict,
KeyType *evict_keys) {
ThreadCtx ctx{};
__shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < n_missing;
offset += ctx.num_warps * warpsize) {
uint32_t n_evict = min(warpsize, n_missing - offset);
if (ctx.lane_id < n_evict) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = missing_keys[idx];
size_t hash = getHash(missing_keys[idx]);
uint32_t bucket_id = hash % nbucket;
blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
for (uint32_t i = 0; i < n_evict; i++) {
uint32_t idx = offset + i;
KeyType key = blockPutKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
KeyType evict_key = 0;
int slot_num = bucket.Evict(ctx, key, &evict_key);
uint32_t evict_keys_idx = 0;
if (num_evict && ctx.lane_id == slot_num) {
evict_keys_idx = atomicAdd(num_evict, 1);
evict_keys[evict_keys_idx] = evict_key;
}
bucket.WriteOneValue(ctx, slot_num,
put_values +
missing_indices[idx] * num_elem_per_value);
bucket.mutex->UnLock(ctx);
}
}
}
template<typename KeyType, typename ElemType>
void FIFOCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query,KeyType *put_keys,
ElemType *put_values, uint32_t *n_evict,KeyType *evict_keys) {
assert(num_query <= maxQueryNum);
if (num_query == 0) {
return;
}
dim3 block(defaultBlockX);
dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
uint32_t *n_missing;
CUDA_CHECK(cudaMallocManaged(&n_missing, sizeof(uint32_t)));
CUDA_CHECK(cudaMemset(n_missing, 0, sizeof(uint32_t)));
if (std::is_same<KeyType, int32_t>::value || std::is_same<KeyType, uint32_t>::value) {
PutWithoutEvictInternal < KeyType, ElemType >
<<<grid, block, uint32SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, num_query, put_keys, put_values,
n_missing, queryKeyBuffer, queryIndiceBuffer);
CUDA_CHECK(cudaStreamSynchronize(stream));
EvictInternal < KeyType, ElemType >
<<<grid, block, uint32SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, put_values, *n_missing,
queryKeyBuffer,
queryIndiceBuffer, n_evict, evict_keys);
} else if (std::is_same<KeyType, int64_t>::value || std::is_same<KeyType, uint64_t>::value) {
PutWithoutEvictInternal < KeyType, ElemType >
<<<grid, block, uint64SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, num_query, put_keys, put_values,
n_missing, queryKeyBuffer, queryIndiceBuffer);
CUDA_CHECK(cudaStreamSynchronize(stream));
EvictInternal < KeyType, ElemType >
<<<grid, block, uint64SharedMemorySize, stream>>>(keys, values, timestamps, bucketMutexes, nbucket,
numElemPerValue, put_values, *n_missing,
queryKeyBuffer,
queryIndiceBuffer, n_evict, evict_keys);
} else {
// std::cout << "unsupported keytype" << std::endl;
throw "Unsupported KeyType";
}
// printf("n_missing = %u\n",*n_missing);
CUDA_CHECK(cudaFree(n_missing));
}
template<typename KeyType, typename ElemType>
void FIFOCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
}
}//namespace fifocache
}// namespace gpucache
namespace gpucache{
namespace fifocache{
template<typename ElemType>
void NewInt32FIFOCache(void **cache, CacheConfig cfg) {
auto c = new FIFOCache<int32_t, ElemType>(cfg);
//std::cout << "NewInt32FIFOCache FIFOCache keysize " << c->keySize << std::endl;
*cache = reinterpret_cast<void *>(c);
}
template<typename ElemType>
void NewInt64FIFOCache(void **cache, CacheConfig cfg) {
auto c = new FIFOCache<int64_t, ElemType>(cfg);
*cache = reinterpret_cast<void *>(c);
}
FIFOCacheWrapper::FIFOCacheWrapper(torch::Tensor t, CacheConfig cfg) {
//std::cout << "FIFOCacheWrapper constructor" << std::endl;
dtype = t.scalar_type();
cache_cfg = cfg;
key_is_int32 = cfg.capacity < INT32_MAX;
if (key_is_int32) {
kdtype = torch::kInt32;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int32_fifo_cache", [&] {
NewInt32FIFOCache<scalar_t>(&fifo_cache, cfg);
});
} else {
kdtype = torch::kInt64;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int64_fifo_cache", [&] {
NewInt64FIFOCache<scalar_t>(&fifo_cache, cfg);
});
}
// AT_DISPATCH_ALL_TYPES(dtype,"fifo_test",[&]{
// auto c = reinterpret_cast<FIFOCache<int32_t,scalar_t>*>(fifo_cache);
// std::cout << "FIFOCacheWrapper constructor FIFOCache dim: " << c->Dim() << " keySize: " << c->KeySize() << "\n";
// });
// std::cout << "FIFOCacheWrapper constructor dtype: " << dtype << " kdtype: " << kdtype << "\n";
}
FIFOCacheWrapper::~FIFOCacheWrapper() {
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_fifo_destructor", [&] {
auto cache = reinterpret_cast<FIFOCache<int32_t, scalar_t> *>(fifo_cache);
delete cache;
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "int64_fifo_destructor", [&] {
auto cache = reinterpret_cast<FIFOCache<int64_t, scalar_t> *>(fifo_cache);
delete cache;
});
}
//std::cout << "~FIFOCacheWrapper()" << std::endl;
}
std::pair<torch::Tensor, torch::Tensor> FIFOCacheWrapper::Get(uint32_t num_query, const torch::Tensor queries) {
assert(queries.device().index() == cache_cfg.deviceId &&
"query keys and cache should be in the same device");
assert(queries.size(0) == num_query && "num_query should equal queries.size(0)");
assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed maxQueryNum");
assert(queries.dtype().toScalarType() == kdtype && "key type doesn't match");
//std::cout << "Get queries: " << queries << std::endl;
// printf("hello!!!\n");
//std::cout << "Get num_query: " << num_query <<"\n";
//std::cout << "Get queries.sizes() " << queries.sizes() << "\n";
auto keys = queries.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
auto result = torch::empty({num_query, cache_cfg.dim},
torch::dtype(dtype).device(torch::kCUDA, cache_cfg.deviceId));
auto find_mask = torch::empty({num_query},
torch::dtype(torch::kBool).device(torch::kCUDA, cache_cfg.deviceId));
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_fifo_get", [&] {
FIFOCache<int32_t, scalar_t> *cache = reinterpret_cast<FIFOCache<int32_t, scalar_t> *>(fifo_cache);
// cache->Test();
//std::cout << "keysize: " << cache->KeySize() << "\n";
//std::cout << "dim: " << cache->Dim() << "\n";
cache->Get(stream, num_query, reinterpret_cast<int32_t *>(keys),
reinterpret_cast<scalar_t *>(result.data_ptr()),
reinterpret_cast<bool *>(find_mask.data_ptr()));
});
} else {
//std::cout << "int64_fifo_get" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int64_fifo_get", [&] {
FIFOCache<int64_t, scalar_t> *cache = reinterpret_cast<FIFOCache<int64_t, scalar_t> *>(fifo_cache);
//std::cout << "cast success!" << std::endl;
cache->Get(stream, num_query, reinterpret_cast<int64_t *>(keys),
reinterpret_cast<scalar_t *>(result.data_ptr()),
reinterpret_cast<bool *>(find_mask.data_ptr()));
});
}
return std::make_pair(result, find_mask);
}
void FIFOCacheWrapper::Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values) {
assert(keys.device().index() == cache_cfg.deviceId && cache_cfg.deviceId == values.device().index() &&
"put keys, put values and cache should be in the same device");
assert(num_query == keys.size(0) && keys.size(0) == values.size(0) &&
"dim 0 of keys,values and num_query should be equal");
assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed maxQueryNum");
assert(values.dtype() == dtype);
auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
if (key_is_int32) {
//std::cout << "int32_fifo_put" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int32_fifo_put", [&] {
auto cache = reinterpret_cast<FIFOCache<int32_t, scalar_t> *>(fifo_cache);
//std::cout << "put_fifo_dim" << cache->Dim() << std::endl;
cache->Put(stream, num_query, reinterpret_cast<int32_t *>(keys.data_ptr()),
reinterpret_cast<scalar_t *>(values.data_ptr()));
});
} else {
//std::cout << "int64_fifo_put" << "\n";
AT_DISPATCH_ALL_TYPES(dtype, "int64_fifo_put", [&] {
auto cache = reinterpret_cast<FIFOCache<int64_t, scalar_t> *>(fifo_cache);
cache->Put(stream, num_query, reinterpret_cast<int64_t *>(keys.data_ptr()),
reinterpret_cast<scalar_t *>(values.data_ptr()));
});
}
}
CacheConfig::CacheEvictStrategy FIFOCacheWrapper::Strategy() {
return CacheConfig::CacheEvictStrategy::FIFO;
}
uint64_t FIFOCacheWrapper::Capacity() { return cache_cfg.capacity; }
uint32_t FIFOCacheWrapper::KeySize() { return cache_cfg.keySize; }
uint32_t FIFOCacheWrapper::ValueSize() { return cache_cfg.valueSize; }
uint32_t FIFOCacheWrapper::MaxQueryNum() { return cache_cfg.maxQueryNum; }
uint64_t FIFOCacheWrapper::DeviceId() { return cache_cfg.deviceId; }
uint32_t FIFOCacheWrapper::Dim() { return cache_cfg.dim; }
void FIFOCacheWrapper::Clear() {
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_fifo_clear", [&] {
auto cache = reinterpret_cast<FIFOCache<int32_t, scalar_t> *>(fifo_cache);
cache->Clear();
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "int64_fifo_clear", [&] {
auto cache = reinterpret_cast<FIFOCache<int64_t, scalar_t> *>(fifo_cache);
cache->Clear();
});
}
}
std::unique_ptr<FIFOCacheWrapper> NewFIFOCache(at::Tensor t, CacheConfig cfg) {
return std::make_unique<FIFOCacheWrapper>(t, cfg);
}
}
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
#include "cache.h"
namespace gpucache {
namespace fifocache {
class FIFOCacheWrapper {
public:
FIFOCacheWrapper(at::Tensor t, CacheConfig cfg);
~FIFOCacheWrapper();
std::pair<torch::Tensor, torch::Tensor> Get(uint32_t num_query, const torch::Tensor queries);
void Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values);
CacheConfig::CacheEvictStrategy Strategy();
uint64_t Capacity();
uint32_t KeySize();
uint32_t ValueSize();
uint32_t MaxQueryNum();
uint64_t DeviceId();
uint32_t Dim();
void Clear();
private:
void *fifo_cache;
c10::ScalarType dtype; // value dtype
c10::ScalarType kdtype; // key dtype
bool key_is_int32; // only support int32 and int64
CacheConfig cache_cfg;
};
std::unique_ptr<FIFOCacheWrapper> NewFIFOCache(at::Tensor t, CacheConfig cfg);
} // namespace fifocache
} // namespace gpucache
\ No newline at end of file
#pragma once
#include <string>
#include "murmurhash3.cuh"
#include <cuda_runtime.h>
__device__ void MurmurHash3_x86_32 ( const void * key, int len,
uint32_t seed, void * out );
namespace gpucache {
constexpr uint32_t lruSeed = 0X12fb73ac;
template<typename T>
__device__ size_t getHash(const T& obj){
uint32_t hash; // MurmurHash3_x86_32 produces a 32-bit result
MurmurHash3_x86_32(reinterpret_cast<const void*>(&obj), sizeof(T), lruSeed, reinterpret_cast<void*>(&hash));
return hash;
}
}
\ No newline at end of file
......@@ -11,7 +11,6 @@
// Platform-specific functions and macros
// Microsoft Visual Studio
#if defined(_MSC_VER)
#define FORCE_INLINE __forceinline
......
#include "hash/hash_fucntion.cuh"
#include "common.cuh"
#include "lru_cache.h"
#include <vector>
#include <string_view>
#include <memory>
//template <typename T>
//constexpr auto type_name() {
// std::string_view name, prefix, suffix;
//#ifdef __clang__
// name = __PRETTY_FUNCTION__;
// prefix = "auto type_name() [T = ";
// suffix = "]";
//#elif defined(__GNUC__)
// name = __PRETTY_FUNCTION__;
// prefix = "constexpr auto type_name() [with T = ";
// suffix = "]";
//#elif defined(_MSC_VER)
// name = __FUNCSIG__;
// prefix = "auto __cdecl type_name<";
// suffix = ">(void)";
//#endif
// name.remove_prefix(prefix.size());
// name.remove_suffix(suffix.size());
// return name;
//}
namespace gpucache {
constexpr unsigned int defaultBlockX = 256;
constexpr unsigned int warpsize = 32;
constexpr unsigned int defaultNumWarpsPerBlock = defaultBlockX / warpsize;
// bucket_id + key
constexpr unsigned int uint32SharedMemorySize = 2 * sizeof(uint32_t) * defaultNumWarpsPerBlock * warpsize;
constexpr unsigned int uint64SharedMemorySize =
(sizeof(uint64_t) + sizeof(uint32_t)) * defaultNumWarpsPerBlock * warpsize;
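// Editorial aside (not part of the commit): worked numbers for the two staging budgets
// above with defaultBlockX = 256, i.e. 8 warps of 32 lanes per block, each lane staging
// one key and one bucket id; these values are passed as the dynamic shared-memory
// argument of the kernel launches below.
static_assert(uint32SharedMemorySize == 2048, "2 KiB of staging per block for 32-bit keys");
static_assert(uint64SharedMemorySize == 3072, "3 KiB of staging per block for 64-bit keys");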
struct ThreadCtx {
__device__ ThreadCtx() {
auto global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
global_warp_idx = global_thread_id / warpsize;
block_warp_idx = threadIdx.x / warpsize;
lane_id = threadIdx.x % warpsize;
num_warps = blockDim.x * gridDim.x / warpsize;
}
namespace lrucache {
// bucket view for warp put/get
template<typename KeyType, typename ElemType>
struct BucketView {
__device__ BucketView(KeyType *k, ElemType *v, uint8_t *ts, WarpMutex *m,
uint32_t num_elems_per_value)
: mutex(m), bkeys(k), bvalues(v), btimestamps(ts),
num_elems_per_value(num_elems_per_value) {}
__device__ int Get(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1;
if (exist_mask == 0) { // not found
return -1;
} else {
auto slot_ts = __shfl_sync(warpFullMask, ts, slot_num);
if (ts > slot_ts) {
btimestamps[ctx.lane_id]--;
}
if (ctx.lane_id == slot_num) {
btimestamps[ctx.lane_id] = warpsize;
}
return slot_num; // return exist slot num
};
}
uint32_t global_warp_idx;
uint32_t block_warp_idx;
uint32_t num_warps;
uint32_t lane_id;
};
// for test
template<typename KeyType>
__global__ void CollectMissingKeysNew(uint32_t num_query, KeyType *keys,
bool *find_mask, uint32_t *n_missing,
KeyType *missing_keys) {
ThreadCtx ctx{};
for (auto offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
auto idx = offset + ctx.lane_id;
if (!find_mask[idx]) {
uint32_t base_missing_idx = atomicAdd(n_missing, 1);
missing_keys[base_missing_idx] = keys[idx];
__device__ int TryPut(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(warpFullMask, exist);
int slot_num = __ffs(exist_mask) - 1;
if (exist_mask != 0) { // find key update value and ts
uint32_t slot_ts = __shfl_sync(warpFullMask, ts, slot_num);
if (ts > slot_ts) {
ts--;
} else if (ctx.lane_id == slot_num) {
ts = warpsize;
}
__syncwarp();
}
if (slot_num == -1) { // key don't exist
uint32_t emptyMask = __ballot_sync(warpFullMask, ts != 0);
if (emptyMask != warpFullMask) { // bucket not full
slot_num = __popc(emptyMask);
if (ts > 0) {
ts--;
} else if (ctx.lane_id == slot_num) {
bkeys[slot_num] = key; // insert key
ts = warpsize;
}
}
__syncwarp();
}
if (slot_num != -1) {
btimestamps[ctx.lane_id] = ts;
}
return slot_num;
}
}
}
struct WarpMutex {
public:
__device__ WarpMutex() : flag(0) {}
__device__ int Evict(const ThreadCtx &ctx, const KeyType key,
KeyType *evict_key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
uint32_t evict_mask = __ballot_sync(warpFullMask, ts == 1);
int slot_num = __ffs(evict_mask) - 1;
if (ts > 1) {
ts--;
} else if (ctx.lane_id == slot_num) {
// printf("evict key %u\n",lane_key);
*evict_key = lane_key;
// printf("evict evict key %u\n",*evict_key);
bkeys[slot_num] = key;
ts = warpsize;
}
btimestamps[ctx.lane_id] = ts;
return slot_num;
}
~WarpMutex() = default;
__device__ void ReadOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *out) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
out[i] = bvalues[slot_num * num_elems_per_value + i];
}
}
WarpMutex(const WarpMutex &) = delete;
__device__ void WriteOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *v) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
bvalues[slot_num * num_elems_per_value + i] = v[i];
}
}
WarpMutex &operator=(const WarpMutex &) = delete;
WarpMutex *mutex;
KeyType *bkeys;
ElemType *bvalues;
uint8_t *btimestamps;
uint32_t num_elems_per_value;
};
WarpMutex(WarpMutex &&) = delete;
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
WarpMutex &operator=(WarpMutex &&) = delete;
template<typename KeyType, typename ElemType>
class LRUCache : public Cache<KeyType, ElemType> {
__device__ void Lock(ThreadCtx &ctx, uint32_t bucket_id) {
if (ctx.lane_id == 0) {
while (atomicCAS(&flag, 0, 1) != 0) {}
}
__threadfence();
__syncwarp();
}
friend BucketView<KeyType, ElemType> __device__ __host__ setBucketView<KeyType, ElemType>(
ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
__device__ void UnLock(ThreadCtx &ctx) {
__syncwarp();
__threadfence();
if (ctx.lane_id == 0) {
auto i = atomicExch(&flag, 0);
// printf("bucket id: Release Lock, flag = %u\n", i);
}
}
public:
explicit LRUCache(const CacheConfig &cfg);
// private:
uint32_t flag;
};
__global__ void initLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (global_thread_idx < n_bucket) {
new(reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx)
WarpMutex();
}
}
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// printf("thread %u CUDA_CHECK lock\n",global_thread_idx);
if (global_thread_idx < n_bucket) {
auto mutex =
reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx;
if (mutex->flag != 0u && mutex->flag != 1u) {
printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx,
mutex->flag);
}
}
}
~LRUCache();
template<typename KeyType, typename ElemType>
class LRUCache;
template<typename KeyType, typename ElemType>
struct BucketView;
uint32_t KeySize() override;
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
uint32_t ValueSize() override;
template<typename KeyType, typename ElemType>
class LRUCache : public Cache<KeyType, ElemType> {
uint64_t Capacity() override;
friend BucketView<KeyType, ElemType> __device__ __host__ setBucketView<KeyType, ElemType>(
ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id);
uint32_t NumElemsPerValue() override;
public:
explicit LRUCache(const CacheConfig &cfg);
uint32_t MaxQueryNum() override;
~LRUCache();
int8_t DeviceId() override;
uint32_t Dim() override;
uint32_t KeySize() override;
CacheConfig::CacheEvictStrategy Strategy() override;
uint32_t ValueSize() override;
void Clear() override;
uint64_t Capacity() override;
void Get(cudaStream_t stream, uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) override;
uint32_t NumElemsPerValue() override;
void Put(cudaStream_t stream, uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_evict = nullptr,
KeyType *evict_keys = nullptr) override;
uint32_t MaxQueryNum() override;
private:
KeyType *keys;
ElemType *values;
uint8_t *timestamps{};
uint32_t nbucket; // 32 values for one bucket
void *bucketMutexes{};
uint32_t NBucket();
// CacheConfig::CacheEvictStrategy strategy;
uint64_t capacity;
uint32_t keySize;
uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim
int8_t device_id;
uint32_t dim;
int8_t DeviceId() override;
// store missing keys and indices for Evict
KeyType *queryKeyBuffer{};
uint32_t *queryIndiceBuffer{};
uint32_t maxQueryNum;
};
uint32_t Dim() override;
void Test();
template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::LRUCache(const CacheConfig &cfg)
: keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum), device_id(cfg.deviceId), dim(cfg.dim) {
// std::cout << "LRUCache constructor" << std::endl;
numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize;
printf("LRUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, device_id);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
CUDA_CHECK(cudaMalloc((void **) &queryKeyBuffer, maxQueryNum * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &queryIndiceBuffer,
maxQueryNum * sizeof(uint32_t)));
}
// for test
// void *Mutex() { return bucketMutexes; }
template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
}
CacheConfig::CacheEvictStrategy Strategy() override;
template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::~LRUCache() {
// TODO need to add CudaDeviceGuard
CUDA_CHECK(cudaFree(keys));
CUDA_CHECK(cudaFree(values));
CUDA_CHECK(cudaFree(timestamps));
CUDA_CHECK(cudaFree(bucketMutexes));
CUDA_CHECK(cudaFree(queryKeyBuffer));
CUDA_CHECK(cudaFree(queryIndiceBuffer));
// std::cout << "LRUCache destructed!" << std::endl;
}
void Clear() override;
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::KeySize() { return keySize; }
void Get(cudaStream_t stream, uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) override;
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::ValueSize() { return valueSize; }
void Put(cudaStream_t stream, uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_evict = nullptr,
KeyType *evict_keys = nullptr) override;
template<typename KeyType, typename ElemType>
uint64_t LRUCache<KeyType, ElemType>::Capacity() { return capacity; }
// private:
KeyType *keys;
ElemType *values;
uint8_t *timestamps{};
uint32_t nbucket; // 32 values for one bucket
void *bucketMutexes{};
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::NumElemsPerValue() { return numElemPerValue; }
// CacheConfig::CacheEvictStrategy strategy;
uint64_t capacity;
uint32_t keySize;
uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim
int8_t device_id;
uint32_t dim;
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::MaxQueryNum() { return maxQueryNum; }
// store missing keys and indices for Evict
KeyType *queryKeyBuffer{};
uint32_t *queryIndiceBuffer{};
uint32_t maxQueryNum;
};
template<typename KeyType, typename ElemType>
int8_t LRUCache<KeyType, ElemType>::DeviceId() { return device_id; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::Dim() { return dim; }
template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::LRUCache(const CacheConfig &cfg)
: keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum), device_id(cfg.deviceId), dim(cfg.dim) {
// std::cout << "LRUCache constructor" << std::endl;
numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize;
printf("LRUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, device_id);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes);
CUDA_CHECK(cudaMalloc((void **) &queryKeyBuffer, maxQueryNum * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &queryIndiceBuffer,
maxQueryNum * sizeof(uint32_t)));
}
template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint32_t)));
dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
initLocks<<<grid, block>>>(nbucket, bucketMutexes
);
}
template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::~LRUCache() {
// TODO need to add CudaDeviceGuard
CUDA_CHECK(cudaFree(keys));
CUDA_CHECK(cudaFree(values));
CUDA_CHECK(cudaFree(timestamps));
CUDA_CHECK(cudaFree(bucketMutexes));
CUDA_CHECK(cudaFree(queryKeyBuffer));
CUDA_CHECK(cudaFree(queryIndiceBuffer));
// std::cout << "LRUCache destructed!" << std::endl;
}
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::KeySize() { return keySize; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::ValueSize() { return valueSize; }
template<typename KeyType, typename ElemType>
uint64_t LRUCache<KeyType, ElemType>::Capacity() { return capacity; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::NumElemsPerValue() { return numElemPerValue; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::MaxQueryNum() { return maxQueryNum; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::NBucket() { return nbucket; }
template<typename KeyType, typename ElemType>
int8_t LRUCache<KeyType, ElemType>::DeviceId() { return device_id; }
template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::Dim() { return dim; }
template<typename KeyType, typename ElemType>
CacheConfig::CacheEvictStrategy LRUCache<KeyType, ElemType>::Strategy() {
return CacheConfig::LRU;
}
// bucket view for warp put/get
template<typename KeyType, typename ElemType>
struct BucketView {
__device__ BucketView(KeyType *k, ElemType *v, uint8_t *ts, WarpMutex *m,
uint32_t num_elems_per_value)
: mutex(m), bkeys(k), bvalues(v), btimestamps(ts),
num_elems_per_value(num_elems_per_value) {}
__device__ int Get(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(0xFFFFFFFF, exist);
int slot_num = __ffs(static_cast<int>(exist_mask)) - 1;
if (exist_mask == 0) { // not found
return -1;
} else {
auto slot_ts = __shfl_sync(0xFFFFFFFF, ts, slot_num);
if (ts > slot_ts) {
btimestamps[ctx.lane_id]--;
}
if (ctx.lane_id == slot_num) {
btimestamps[ctx.lane_id] = warpsize;
}
return slot_num; // return exist slot num
};
template<typename KeyType, typename ElemType>
CacheConfig::CacheEvictStrategy LRUCache<KeyType, ElemType>::Strategy() {
return CacheConfig::LRU;
}
__device__ int TryPut(const ThreadCtx &ctx, const KeyType key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
bool exist = (ts > 0 && lane_key == key);
uint32_t exist_mask = __ballot_sync(0xFFFFFFFF, exist);
int slot_num = __ffs(static_cast<int>(exist_mask)) - 1;
if (exist_mask != 0) { // find key just update value
uint32_t slot_ts = __shfl_sync(0xFFFFFFFF, ts, slot_num);
if (ts > slot_ts) {
ts--;
} else if (ctx.lane_id == slot_num) {
ts = warpsize;
}
__syncwarp();
}
if (slot_num == -1) { // key don't exist
uint32_t emptyMask = __ballot_sync(0xFFFFFFFF, ts != 0);
if (emptyMask != 0xFFFFFFFF) { // bucket not full
slot_num = __popc(static_cast<int>(emptyMask));
if (ts > 0) {
ts--;
} else if (ctx.lane_id == slot_num) {
bkeys[slot_num] = key; // insert key
ts = warpsize;
}
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id) {
return BucketView<KeyType, ElemType>(
cache_keys + bucket_id * warpsize,
cache_values + bucket_id * warpsize * num_elem_per_value,
cache_timestamps + bucket_id * warpsize,
reinterpret_cast<WarpMutex *>(cache_mutexes) + bucket_id,
num_elem_per_value);
}
template<typename KeyType, typename ElemType>
__global__ void GetInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) {
ThreadCtx ctx{};
__shared__ KeyType blockQueryKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = queries[idx];
size_t hash = getHash<KeyType>(key);
uint32_t bucket_id = hash % nbucket;
blockQueryKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
}
if (slot_num != -1) {
btimestamps[ctx.lane_id] = ts;
}
return slot_num;
}
// 32 threads compare it own slot with key
// if find parallel write to result
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
__device__ int Evict(const ThreadCtx &ctx, const KeyType key,
KeyType *evict_key) {
KeyType lane_key = bkeys[ctx.lane_id];
uint8_t ts = btimestamps[ctx.lane_id];
uint32_t evict_mask = __ballot_sync(0xFFFFFFFF, ts == 1);
int slot_num = __ffs(static_cast<int>(evict_mask)) - 1;
if (ts > 1) {
ts--;
} else if (ctx.lane_id == slot_num) {
// printf("evict key %u\n",lane_key);
*evict_key = lane_key;
// printf("evict evict key %u\n",*evict_key);
bkeys[slot_num] = key;
ts = warpsize;
}
btimestamps[ctx.lane_id] = ts;
return slot_num;
}
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
__device__ void ReadOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *out) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
out[i] = bvalues[slot_num * num_elems_per_value + i];
}
}
bucket.mutex->Lock(ctx, bucket_id);
__device__ void WriteOneValue(const ThreadCtx &ctx, uint8_t slot_num,
ElemType *v) {
for (size_t i = ctx.lane_id; i < num_elems_per_value; i += warpsize) {
bvalues[slot_num * num_elems_per_value + i] = v[i];
int slot_num = bucket.Get(ctx, key);
if (slot_num != -1) {
find_mask[idx] = true;
bucket.ReadOneValue(ctx, slot_num, &results[idx * num_elem_per_value]);
}
bucket.mutex->UnLock(ctx);
}
__syncwarp();
}
}
WarpMutex *mutex;
KeyType *bkeys;
ElemType *bvalues;
uint8_t *btimestamps;
uint32_t num_elems_per_value;
};
template<typename KeyType, typename ElemType>
__device__ __host__ BucketView<KeyType, ElemType>
setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t num_elem_per_value, uint32_t bucket_id) {
// auto p = reinterpret_cast<WarpMutex *>(cache_mutexes) + bucket_id;
// if (p->flag != 0 && p->flag != 1) {
// if (ctx.lane_id == 0) {
// printf("warp %u bucket %u flag not equal 0 or 1 is
// %u\n",ctx.global_warp_idx,bucket_id,p->flag);
// }
// }
return BucketView<KeyType, ElemType>(
cache_keys + bucket_id * warpsize,
cache_values + bucket_id * warpsize * num_elem_per_value,
cache_timestamps + bucket_id * warpsize,
reinterpret_cast<WarpMutex *>(cache_mutexes) + bucket_id,
num_elem_per_value);
}
template<typename KeyType, typename ElemType>
__global__ void GetInternal(KeyType *cache_keys, ElemType *cache_values,
template<typename KeyType, typename ElemType>
__global__ void
PutWithoutEvictInternal(KeyType *cache_keys, ElemType *cache_values,
uint8_t *cache_timestamps, void *cache_mutexes,
uint32_t nbucket, uint32_t num_elem_per_value,
uint32_t num_query, KeyType *queries,
ElemType *results, bool *find_mask) {
ThreadCtx ctx{};
__shared__ KeyType blockQueryKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = queries[idx];
size_t hash = getHash<KeyType>(key);
uint32_t bucket_id = hash % nbucket;
blockQueryKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
__syncwarp();
// 32 threads compare it own slot with key
// if find parallel write to result
for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
auto bucket = setBucketView<KeyType, ElemType>(
ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
num_elem_per_value, bucket_id);
bucket.mutex->Lock(ctx, bucket_id);
int slot_num = bucket.Get(ctx, key);
if (slot_num != -1) {
find_mask[idx] = true;
bucket.ReadOneValue(ctx, slot_num, &results[idx * num_elem_per_value]);
uint32_t num_query, KeyType *put_keys,
ElemType *put_values, uint32_t *n_missing,
KeyType *missing_keys, uint32_t *missing_indices) {
ThreadCtx ctx{};
__shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
__shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
offset += ctx.num_warps * warpsize) {
uint32_t n_query = min(warpsize, num_query - offset);
if (ctx.lane_id < n_query) {
uint32_t idx = offset + ctx.lane_id;
KeyType key = put_keys[idx];
size_t hash = getHash(key);
uint32_t bucket_id = hash % nbucket;
blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
}
bucket.mutex->UnLock(ctx);
}
__syncwarp();
}
}
  template<typename KeyType, typename ElemType>
  __global__ void
  PutWithoutEvictInternal(KeyType *cache_keys, ElemType *cache_values,
                          uint8_t *cache_timestamps, void *cache_mutexes,
                          uint32_t nbucket, uint32_t num_elem_per_value,
                          uint32_t num_query, KeyType *put_keys,
                          ElemType *put_values, uint32_t *n_missing,
                          KeyType *missing_keys, uint32_t *missing_indices) {
    ThreadCtx ctx{};
    __shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
    __shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
    for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < num_query;
         offset += ctx.num_warps * warpsize) {
      uint32_t n_query = min(warpsize, num_query - offset);
      if (ctx.lane_id < n_query) {
        uint32_t idx = offset + ctx.lane_id;
        KeyType key = put_keys[idx];
        size_t hash = getHash(key);
        uint32_t bucket_id = hash % nbucket;
        blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
        blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
      }
      __syncwarp();
      uint32_t n_warp_missing = 0;
      KeyType warp_missing_key;
      uint32_t warp_missing_index = 0;
      // the 32 lanes handle the n_query keys cooperatively instead of 1 thread per key
      for (uint32_t i = 0; i < n_query; i++) {
        uint32_t idx = offset + i;
        KeyType key = blockPutKeys[ctx.block_warp_idx][i];
        uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
        auto bucket = setBucketView<KeyType, ElemType>(
            ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
            num_elem_per_value, bucket_id);
        bucket.mutex->Lock(ctx, bucket_id);
        int slot_num = bucket.TryPut(ctx, key);
        if (slot_num != -1) { // insert value
          bucket.WriteOneValue(ctx, slot_num,
                               &put_values[idx * num_elem_per_value]);
        } else { // bucket full: record missing key + idx for Evict
          if (ctx.lane_id == n_warp_missing) { // i-th lane keeps the i-th missing key
            warp_missing_key = key;
            warp_missing_index = idx;
          }
          n_warp_missing += 1;
        }
        bucket.mutex->UnLock(ctx);
      }
      // reduce missing_key & idx
      if (n_warp_missing > 0) {
        uint32_t base_missing_idx = 0;
        if (ctx.lane_id == 0) {
          base_missing_idx = atomicAdd(n_missing, n_warp_missing);
        }
        base_missing_idx = __shfl_sync(warpFullMask, base_missing_idx, 0);
        if (ctx.lane_id < n_warp_missing) {
          missing_keys[base_missing_idx + ctx.lane_id] = warp_missing_key;
          missing_indices[base_missing_idx + ctx.lane_id] = warp_missing_index;
        }
      }
      __syncwarp();
    }
  }
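  // -------------------------------------------------------------------------
  // Editorial aside (not part of this commit): the "reduce missing_key & idx"
  // step above is a warp-aggregated atomic: lane 0 reserves a contiguous range
  // in the global output with a single atomicAdd, the base offset is broadcast
  // with __shfl_sync, and each lane then writes its own slot. A minimal,
  // self-contained sketch of that pattern follows; CompactFlaggedKeys and its
  // parameters are hypothetical and only illustrate the technique.
  __global__ void CompactFlaggedKeys(uint32_t n, const uint32_t *keys,
                                     const bool *flagged, uint32_t *n_out,
                                     uint32_t *out_keys) {
    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    uint32_t lane = threadIdx.x % 32;
    bool pred = tid < n && flagged[tid];
    unsigned mask = __ballot_sync(0xFFFFFFFF, pred);       // one bit per flagged lane
    uint32_t warp_count = __popc(mask);                    // flagged lanes in this warp
    uint32_t base = 0;
    if (lane == 0 && warp_count > 0) {
      base = atomicAdd(n_out, warp_count);                 // reserve the output range once per warp
    }
    base = __shfl_sync(0xFFFFFFFF, base, 0);               // broadcast the base offset
    if (pred) {
      uint32_t rank = __popc(mask & ((1u << lane) - 1u));  // flagged lanes before this one
      out_keys[base + rank] = keys[tid];
    }
  }
  // -------------------------------------------------------------------------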
  template<typename KeyType, typename ElemType>
  __global__ void
  EvictInternal(KeyType *cache_keys, ElemType *cache_values,
                uint8_t *cache_timestamps, void *cache_mutexes, uint32_t nbucket,
                uint32_t num_elem_per_value, ElemType *put_values,
                uint32_t n_missing, KeyType *missing_keys,
                uint32_t *missing_indices, uint32_t *num_evict,
                KeyType *evict_keys) {
    ThreadCtx ctx{};
    __shared__ KeyType blockPutKeys[defaultNumWarpsPerBlock][warpsize];
    __shared__ uint32_t blockBucketIds[defaultNumWarpsPerBlock][warpsize];
    for (uint32_t offset = ctx.global_warp_idx * warpsize; offset < n_missing;
         offset += ctx.num_warps * warpsize) {
      uint32_t n_evict = min(warpsize, n_missing - offset);
      if (ctx.lane_id < n_evict) {
        uint32_t idx = offset + ctx.lane_id;
        KeyType key = missing_keys[idx];
        size_t hash = getHash(missing_keys[idx]);
        uint32_t bucket_id = hash % nbucket;
        blockPutKeys[ctx.block_warp_idx][ctx.lane_id] = key;
        blockBucketIds[ctx.block_warp_idx][ctx.lane_id] = bucket_id;
      }
      __syncwarp();
      for (uint32_t i = 0; i < n_evict; i++) {
        uint32_t idx = offset + i;
        KeyType key = blockPutKeys[ctx.block_warp_idx][i];
        uint32_t bucket_id = blockBucketIds[ctx.block_warp_idx][i];
        auto bucket = setBucketView<KeyType, ElemType>(
            ctx, cache_keys, cache_values, cache_timestamps, cache_mutexes,
            num_elem_per_value, bucket_id);
        bucket.mutex->Lock(ctx, bucket_id);
        KeyType evict_key = 0;
        int slot_num = bucket.Evict(ctx, key, &evict_key);
        uint32_t evict_keys_idx = 0;
        if (num_evict && ctx.lane_id == slot_num) {
          evict_keys_idx = atomicAdd(num_evict, 1);
          evict_keys[evict_keys_idx] = evict_key;
        }
        bucket.WriteOneValue(ctx, slot_num,
                             put_values + missing_indices[idx] * num_elem_per_value);
        bucket.mutex->UnLock(ctx);
      }
    }
  }
  // TODO async cudaMemcpy
  template<typename KeyType, typename ElemType>
  void LRUCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query,
                                        KeyType *put_keys, ElemType *put_values,
                                        uint32_t *n_evict, KeyType *evict_keys) {
    assert(num_query <= maxQueryNum);
    if (num_query == 0) {
      return;
    }
    dim3 block(defaultBlockX);
    dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
    uint32_t *n_missing;
    CUDA_CHECK(cudaMallocManaged(&n_missing, sizeof(uint32_t)));
    CUDA_CHECK(cudaMemset(n_missing, 0, sizeof(uint32_t)));
    if (std::is_same<KeyType, int32_t>::value || std::is_same<KeyType, uint32_t>::value) {
      PutWithoutEvictInternal<KeyType, ElemType>
          <<<grid, block, uint32SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              num_query, put_keys, put_values, n_missing, queryKeyBuffer,
              queryIndiceBuffer);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      EvictInternal<KeyType, ElemType>
          <<<grid, block, uint32SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              put_values, *n_missing, queryKeyBuffer, queryIndiceBuffer,
              n_evict, evict_keys);
    } else if (std::is_same<KeyType, int64_t>::value || std::is_same<KeyType, uint64_t>::value) {
      PutWithoutEvictInternal<KeyType, ElemType>
          <<<grid, block, uint64SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              num_query, put_keys, put_values, n_missing, queryKeyBuffer,
              queryIndiceBuffer);
      CUDA_CHECK(cudaStreamSynchronize(stream));
      EvictInternal<KeyType, ElemType>
          <<<grid, block, uint64SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              put_values, *n_missing, queryKeyBuffer, queryIndiceBuffer,
              n_evict, evict_keys);
    } else {
      throw "Unsupported KeyType";
    }
    CUDA_CHECK(cudaFree(n_missing));
  }
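  // Editorial note: Put is a two-phase operation. PutWithoutEvictInternal fills
  // free slots and records the keys that found no slot (plus their index into
  // put_values) in queryKeyBuffer / queryIndiceBuffer via n_missing;
  // EvictInternal then evicts one entry per missing key from its bucket, writes
  // the new value in its place, and optionally reports the evicted keys through
  // n_evict / evict_keys.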
template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Test(){
std::cout << "LRCache Test" << std::endl;
}
  // TODO async cudaMemcpy
  template<typename KeyType, typename ElemType>
  void LRUCache<KeyType, ElemType>::Get(cudaStream_t stream, uint32_t num_query,
                                        KeyType *queries, ElemType *results,
                                        bool *find_mask) {
    assert(num_query <= maxQueryNum && "num_query should not exceed maxQueryNum");
    if (num_query == 0) {
      return;
    }
    dim3 block(defaultBlockX);
    dim3 grid((num_query + defaultBlockX - 1) / defaultBlockX);
    if (std::is_same<KeyType, int32_t>::value) {
      GetInternal<KeyType, ElemType>
          <<<grid, block, uint32SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              num_query, queries, results, find_mask);
    } else if (std::is_same<KeyType, int64_t>::value) {
      GetInternal<KeyType, ElemType>
          <<<grid, block, uint64SharedMemorySize, stream>>>(
              keys, values, timestamps, bucketMutexes, nbucket, numElemPerValue,
              num_query, queries, results, find_mask);
    } else {
      throw "Unsupported KeyType!";
    }
  }
} // namespace lrucache
} // namespace gpucache
// LRUCacheWrapper API
namespace gpucache {
namespace lrucache {
template<typename ElemType>
void NewInt32LRUCache(void **cache, CacheConfig cfg) {
auto c = new LRUCache<int32_t, ElemType>(cfg);
//std::cout << "NewInt32LRUCache LRUCache keysize " << c->keySize << std::endl;
*cache = reinterpret_cast<void *>(c);
}
template<typename ElemType>
void NewInt64LRUCache(void **cache, CacheConfig cfg) {
auto c = new LRUCache<int64_t, ElemType>(cfg);
*cache = reinterpret_cast<void *>(c);
}
LRUCacheWrapper::LRUCacheWrapper(torch::Tensor t, CacheConfig cfg) {
//std::cout << "LRUCacheWrapper constructor" << std::endl;
dtype = t.scalar_type();
cache_cfg = cfg;
key_is_int32 = cfg.capacity < INT32_MAX;
if (key_is_int32) {
kdtype = torch::kInt32;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int32_lru_cache", [&] {
NewInt32LRUCache<scalar_t>(&lru_cache, cfg);
});
} else {
kdtype = torch::kInt64;
AT_DISPATCH_ALL_TYPES(t.scalar_type(), "int64_lru_cache", [&] {
NewInt64LRUCache<scalar_t>(&lru_cache, cfg);
});
}
// AT_DISPATCH_ALL_TYPES(dtype,"lru_test",[&]{
// auto c = reinterpret_cast<LRUCache<int32_t,scalar_t>*>(lru_cache);
// std::cout << "LRUCacheWrapper constructor LRUCache dim: " << c->Dim() << " keySize: " << c->KeySize() << "\n";
// });
// std::cout << "LRUCacheWrapper constructor dtype: " << dtype << " kdtype: " << kdtype << "\n";
}
LRUCacheWrapper::~LRUCacheWrapper() {
if (key_is_int32) {
AT_DISPATCH_ALL_TYPES(dtype, "int32_lru_destructor", [&] {
auto cache = reinterpret_cast<LRUCache<int32_t, scalar_t> *>(lru_cache);
delete cache;
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "int64_lru_destructor", [&] {
auto cache = reinterpret_cast<LRUCache<int64_t, scalar_t> *>(lru_cache);
delete cache;
});
}
//std::cout << "~LRUCacheWrapper()" << std::endl;
}
    std::pair<torch::Tensor, torch::Tensor> LRUCacheWrapper::Get(uint32_t num_query, const torch::Tensor queries) {
      assert(queries.device().index() == cache_cfg.deviceId &&
             "query keys and the cache should be on the same device");
      assert(queries.size(0) == num_query && "num_query should equal queries.size(0)");
      assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed max_query_num");
      assert(queries.dtype().toScalarType() == kdtype && "key type doesn't match");
      auto keys = queries.data_ptr();
      auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
      auto result = torch::empty({num_query, cache_cfg.dim},
                                 torch::dtype(dtype).device(torch::kCUDA, cache_cfg.deviceId));
      auto find_mask = torch::empty({num_query},
                                    torch::dtype(torch::kBool).device(torch::kCUDA, cache_cfg.deviceId));
      if (key_is_int32) {
        AT_DISPATCH_ALL_TYPES(dtype, "int32_lru_get", [&] {
          auto cache = reinterpret_cast<LRUCache<int32_t, scalar_t> *>(lru_cache);
          cache->Get(stream, num_query, reinterpret_cast<int32_t *>(keys),
                     reinterpret_cast<scalar_t *>(result.data_ptr()),
                     reinterpret_cast<bool *>(find_mask.data_ptr()));
        });
      } else {
        AT_DISPATCH_ALL_TYPES(dtype, "int64_lru_get", [&] {
          auto cache = reinterpret_cast<LRUCache<int64_t, scalar_t> *>(lru_cache);
          cache->Get(stream, num_query, reinterpret_cast<int64_t *>(keys),
                     reinterpret_cast<scalar_t *>(result.data_ptr()),
                     reinterpret_cast<bool *>(find_mask.data_ptr()));
        });
      }
      return std::make_pair(result, find_mask);
    }
    void LRUCacheWrapper::Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values) {
      assert(keys.device().index() == cache_cfg.deviceId && cache_cfg.deviceId == values.device().index() &&
             "put keys, put values and the cache should be on the same device");
      assert(num_query == keys.size(0) && keys.size(0) == values.size(0) &&
             "dim 0 of keys, values and num_query should be equal");
      assert(num_query <= cache_cfg.maxQueryNum && "num_query must not exceed max_query_num");
      assert(values.dtype() == dtype);
      auto stream = at::cuda::getCurrentCUDAStream(cache_cfg.deviceId).stream();
      if (key_is_int32) {
        AT_DISPATCH_ALL_TYPES(dtype, "int32_lru_put", [&] {
          auto cache = reinterpret_cast<LRUCache<int32_t, scalar_t> *>(lru_cache);
          cache->Put(stream, num_query, reinterpret_cast<int32_t *>(keys.data_ptr()),
                     reinterpret_cast<scalar_t *>(values.data_ptr()));
        });
      } else {
        AT_DISPATCH_ALL_TYPES(dtype, "int64_lru_put", [&] {
          auto cache = reinterpret_cast<LRUCache<int64_t, scalar_t> *>(lru_cache);
          cache->Put(stream, num_query, reinterpret_cast<int64_t *>(keys.data_ptr()),
                     reinterpret_cast<scalar_t *>(values.data_ptr()));
        });
      }
    }
    CacheConfig::CacheEvictStrategy LRUCacheWrapper::Strategy() {
      return CacheConfig::CacheEvictStrategy::LRU;
    }

    uint64_t LRUCacheWrapper::Capacity() { return cache_cfg.capacity; }

    uint32_t LRUCacheWrapper::KeySize() { return cache_cfg.keySize; }

    uint32_t LRUCacheWrapper::ValueSize() { return cache_cfg.valueSize; }

    uint32_t LRUCacheWrapper::MaxQueryNum() { return cache_cfg.maxQueryNum; }

    uint64_t LRUCacheWrapper::DeviceId() { return cache_cfg.deviceId; }

    uint32_t LRUCacheWrapper::Dim() { return cache_cfg.dim; }
    void LRUCacheWrapper::Clear() {
      if (key_is_int32) {
        AT_DISPATCH_ALL_TYPES(dtype, "int32_lru_clear", [&] {
          auto cache = reinterpret_cast<LRUCache<int32_t, scalar_t> *>(lru_cache);
          cache->Clear();
        });
      } else {
        AT_DISPATCH_ALL_TYPES(dtype, "int64_lru_clear", [&] {
          auto cache = reinterpret_cast<LRUCache<int64_t, scalar_t> *>(lru_cache);
          cache->Clear();
        });
      }
    }
std::unique_ptr<LRUCacheWrapper> NewLRUCache(at::Tensor t, CacheConfig cfg) {
return std::make_unique<LRUCacheWrapper>(t, cfg);
}
} // namespace lrucache
} // namespace gpucache
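// ---------------------------------------------------------------------------
// Editorial aside (sketch, not part of this commit): driving the templated
// LRUCache directly from host code, without the torch wrapper. The float value
// type, buffer sizes, and ExampleDirectUse itself are illustrative assumptions;
// the CacheConfig fields and the Get/Put signatures are the ones defined above.
void ExampleDirectUse() {
  using namespace gpucache;
  CacheConfig cfg{};
  cfg.strategy = CacheConfig::CacheEvictStrategy::LRU;
  cfg.capacity = 1 << 16;
  cfg.keySize = sizeof(int32_t);
  cfg.dim = 32;
  cfg.valueSize = cfg.dim * sizeof(float);
  cfg.maxQueryNum = 4096;
  cfg.deviceId = 0;
  lrucache::LRUCache<int32_t, float> cache(cfg);

  constexpr uint32_t n = 1024;  // must stay <= cfg.maxQueryNum
  int32_t *d_keys;
  float *d_values, *d_results;
  bool *d_mask;
  CUDA_CHECK(cudaMalloc((void **)&d_keys, n * sizeof(int32_t)));
  CUDA_CHECK(cudaMalloc((void **)&d_values, n * cfg.dim * sizeof(float)));
  CUDA_CHECK(cudaMalloc((void **)&d_results, n * cfg.dim * sizeof(float)));
  CUDA_CHECK(cudaMalloc((void **)&d_mask, n * sizeof(bool)));
  // ... fill d_keys / d_values on the device ...

  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  cache.Put(stream, n, d_keys, d_values, /*n_evict=*/nullptr, /*evict_keys=*/nullptr);
  cache.Get(stream, n, d_keys, d_results, d_mask);
  CUDA_CHECK(cudaStreamSynchronize(stream));

  CUDA_CHECK(cudaStreamDestroy(stream));
  CUDA_CHECK(cudaFree(d_keys));
  CUDA_CHECK(cudaFree(d_values));
  CUDA_CHECK(cudaFree(d_results));
  CUDA_CHECK(cudaFree(d_mask));
}
// ---------------------------------------------------------------------------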
#pragma once
#include "common.cuh"
#include <torch/extension.h>
#include "cache.h"
namespace gpucache {
namespace lrucache {
class LRUCacheWrapper {
public:
LRUCacheWrapper(at::Tensor t, CacheConfig cfg);
~LRUCacheWrapper();
std::pair<torch::Tensor, torch::Tensor> Get(uint32_t num_query, const torch::Tensor queries);
void Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values);
CacheConfig::CacheEvictStrategy Strategy();
uint64_t Capacity();
uint32_t KeySize();
uint32_t ValueSize();
uint32_t MaxQueryNum();
uint64_t DeviceId();
uint32_t Dim();
void Clear();
private:
void *lru_cache;
c10::ScalarType dtype; // value dtype
c10::ScalarType kdtype; // key dtype
bool key_is_int32; // only support int32 and int64
CacheConfig cache_cfg;
};
// template<typename KeyType, typename ElemType>
// class LRUCache;
//
// template<typename KeyType, typename ElemType>
// struct BucketView;
//
// struct ThreadCtx;
//
// template<typename KeyType, typename ElemType>
// __device__ __host__ BucketView<KeyType, ElemType>
// setBucketView(ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
// uint8_t *cache_timestamps, void *cache_mutexes,
// uint32_t num_elem_per_value, uint32_t bucket_id);
//
// template<typename KeyType, typename ElemType>
// class LRUCache : public Cache<KeyType, ElemType> {
//
// friend BucketView<KeyType, ElemType> __device__ __host__ setBucketView<KeyType, ElemType>(
// ThreadCtx ctx, KeyType *cache_keys, ElemType *cache_values,
// uint8_t *cache_timestamps, void *cache_mutexes,
// uint32_t num_elem_per_value, uint32_t bucket_id);
//
// public:
// explicit LRUCache(const CacheConfig &cfg);
//
//
// ~LRUCache();
//
// uint32_t KeySize() override;
//
// uint32_t ValueSize() override;
//
// uint64_t Capacity() override;
//
// uint32_t NumElemsPerValue() override;
//
// uint32_t MaxQueryNum();
//
// uint32_t NBucket();
//
// int8_t DeviceId() override;
//
// uint32_t Dim() override;
//
// // for test
// // void *Mutex() { return bucketMutexes; }
//
// CacheConfig::CacheEvictStrategy Strategy() override;
//
// void Clear() override;
//
// void Get(cudaStream_t stream, uint32_t num_query, KeyType *queries,
// ElemType *results, bool *find_mask) override;
//
// void Put(cudaStream_t stream, uint32_t num_query, KeyType *putkeys,
// ElemType *putvalues, uint32_t *n_evict = nullptr,
// KeyType *evict_keys = nullptr) override;
//
// private:
// KeyType *keys;
// ElemType *values;
// uint8_t *timestamps{};
// uint32_t nbucket; // 32 values for one bucket
// void *bucketMutexes{};
//
// // CacheConfig::CacheEvictStrategy strategy;
// uint64_t capacity;
// uint32_t keySize;
// uint32_t valueSize;
// uint32_t numElemPerValue; // embedding dim
// int8_t device_id;
// uint32_t dim;
//
// // store missing keys and indices for Evict
// KeyType *queryKeyBuffer{};
// uint32_t *queryIndiceBuffer{};
// uint32_t maxQueryNum;
// };
std::unique_ptr<LRUCacheWrapper> NewLRUCache(at::Tensor t, CacheConfig cfg);
} // namespace lrucache
} // namespace gpucache
\ No newline at end of file
#include "utils.cuh"
namespace gpucache {
__device__ ThreadCtx::ThreadCtx() {
auto global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
global_warp_idx = global_thread_id / warpsize;
block_warp_idx = threadIdx.x / warpsize;
lane_id = threadIdx.x % warpsize;
num_warps = blockDim.x * gridDim.x / warpsize;
}
__device__ WarpMutex::WarpMutex() : flag(0) {}
__device__ void WarpMutex::Lock(ThreadCtx &ctx, uint32_t bucket_id) {
if (ctx.lane_id == 0) {
while (atomicCAS(&flag, 0, 1) != 0) {}
}
__threadfence();
__syncwarp();
}
__device__ void WarpMutex::UnLock(ThreadCtx &ctx) {
__syncwarp();
__threadfence();
if (ctx.lane_id == 0) {
atomicExch(&flag, 0);
}
}
__global__ void initLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (global_thread_idx < n_bucket) {
new(reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx)
WarpMutex();
}
}
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// printf("thread %u CUDA_CHECK lock\n",global_thread_idx);
if (global_thread_idx < n_bucket) {
auto mutex =
reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx;
if (mutex->flag != 0u && mutex->flag != 1u) {
printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx,
mutex->flag);
}
}
}
}
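// ---------------------------------------------------------------------------
// Editorial aside (sketch, not part of this commit): the intended usage pattern
// of WarpMutex inside a kernel. All 32 lanes of a warp call Lock/UnLock
// together; only lane 0 spins on the flag, and the __threadfence/__syncwarp
// pair in Lock/UnLock orders the work done while the lock is held. The kernel
// name and the bucket_counters parameter are illustrative only.
__global__ void TouchBucket(uint32_t nbucket, void *bucketMutexes,
                            uint32_t *bucket_counters) {
  gpucache::ThreadCtx ctx{};
  for (uint32_t bucket_id = ctx.global_warp_idx; bucket_id < nbucket;
       bucket_id += ctx.num_warps) {
    auto *mutex = reinterpret_cast<gpucache::WarpMutex *>(bucketMutexes) + bucket_id;
    mutex->Lock(ctx, bucket_id);        // whole warp enters the critical section
    if (ctx.lane_id == 0) {
      bucket_counters[bucket_id] += 1;  // exclusive access to this bucket
    }
    mutex->UnLock(ctx);                 // whole warp leaves the critical section
  }
}
// ---------------------------------------------------------------------------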
#pragma once
#include <cuda_runtime.h>
namespace gpucache{
constexpr unsigned int warpFullMask = 0xFFFFFFFF;
constexpr unsigned int defaultBlockX = 256;
constexpr unsigned int warpsize = 32;
constexpr unsigned int defaultNumWarpsPerBlock = defaultBlockX / warpsize;
// bucket_id + key
constexpr unsigned int uint32SharedMemorySize = 2 * sizeof(uint32_t) * defaultNumWarpsPerBlock * warpsize;
constexpr unsigned int uint64SharedMemorySize =
(sizeof(uint64_t) + sizeof(uint32_t)) * defaultNumWarpsPerBlock * warpsize;
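    // Editorial note (arithmetic added for clarity, not part of the original
    // source): with the defaults above (8 warps per block, 32 lanes per warp),
    // 32-bit keys stage 8 * 32 * (4 B key + 4 B bucket_id) = 2048 B of shared
    // memory per block, and 64-bit keys stage 8 * 32 * (8 B + 4 B) = 3072 B.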
struct ThreadCtx {
__device__ ThreadCtx();
uint32_t global_warp_idx;
uint32_t block_warp_idx;
uint32_t num_warps;
uint32_t lane_id;
};
struct WarpMutex {
public:
__device__ WarpMutex();
~WarpMutex() = default;
WarpMutex(const WarpMutex &) = delete;
WarpMutex &operator=(const WarpMutex &) = delete;
WarpMutex(WarpMutex &&) = delete;
WarpMutex &operator=(WarpMutex &&) = delete;
__device__ void Lock(ThreadCtx &ctx, uint32_t bucket_id);
__device__ void UnLock(ThreadCtx &ctx);
// private:
uint32_t flag;
};
__global__ void initLocks(uint32_t n_bucket, void *bucketMutexes);
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes);
}
......@@ -16,8 +16,9 @@
"metadata": {},
"outputs": [],
"source": [
"# args: strtegy, capacity, key_size(only support int32 and int64), value_size(dim * sizeof(elem), elem is decided by the passing tensor dtype of NewLRUCache), device_id, dim\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LRU,65536,4,128,4096,0,32)"
"# args: strtegy, capacity, keySize(only support int32 and int64), valueSize(dim * sizeof(elem), elem is decided by the passing tensor dtype of NewLRUCache), deviceId, dim\n",
"# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LRU,65536,4,128,4096,0,32)\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.FIFO,65536,4,128,4096,0,32)"
]
},
{
......@@ -28,7 +29,7 @@
{
"data": {
"text/plain": [
"(<CacheEvictStrategy.LRU: 1>, 65536, 4, 128, 4096, 0, 32)"
"(<CacheEvictStrategy.FIFO: 0>, 65536, 4, 128, 4096, 0, 32)"
]
},
"execution_count": 3,
......@@ -37,7 +38,7 @@
}
],
"source": [
"cfg.strategy, cfg.capacity, cfg.key_size, cfg.value_size, cfg.max_query_num, cfg.device_id, cfg.dim"
"cfg.strategy, cfg.capacity, cfg.keySize, cfg.valueSize, cfg.maxQueryNum, cfg.deviceId, cfg.dim"
]
},
{
......@@ -49,13 +50,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"LRUCache: keySize: 4, valueSize: 128, dim: 32, capacity: 65536, maxQueryNum: 4096, deviceId: 0\n"
"FIFOCache: keySize: 4, valueSize: 128, dim: 32, capacity: 65536, maxQueryNum: 4096, deviceId: 0\n"
]
}
],
"source": [
"t = torch.empty([1],dtype=torch.float32)\n",
"cache = libgpucache.NewLRUCache(t,cfg)"
"# cache = libgpucache.NewLRUCache(t,cfg)\n",
"cache = libgpucache.NewFIFOCache(t,cfg)"
]
},
{
......@@ -88,11 +90,6 @@
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
......@@ -135,6 +132,38 @@
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,\n",
" 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,\n",
" 29., 30., 31., 32.],\n",
" [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,\n",
" 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,\n",
" 60., 61., 62., 63.],\n",
" [64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77.,\n",
" 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91.,\n",
" 92., 93., 94., 95.]], device='cuda:0'),\n",
" tensor([False, True, True], device='cuda:0'),\n",
" torch.Size([3, 32]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"keys[0] = 999\n",
"values, find_mask = cache.Get(3,keys)\n",
"values, find_mask, values.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
......
......@@ -34,6 +34,7 @@ enable_testing()
add_executable(
cache_test
cache_test.cu
)
target_include_directories(cache_test PRIVATE /home/wtx/miniconda3/envs/dgl/include/python3.10) # path to your Python.h
......
#include "gtest/gtest.h"
#include "../src/hash/murmurhash3.cu"
#include "../src/utils.cuh"
#include "../src/utils.cu"
#include "../src/lru_cache.h"
#include "../src/lru_cache.cu"
#include "../src/fifo_cache.h"
#include "../src/fifo_cache.cu"
#include <cuda_runtime.h>
#include <unordered_set>
#include <vector>
......@@ -9,10 +14,23 @@
#include <algorithm>
namespace gpucache{
// for test
  template<typename KeyType>
  __global__ void CollectMissingKeys(uint32_t num_query, KeyType *keys,
                                     bool *find_mask, uint32_t *n_missing,
                                     KeyType *missing_keys) {
    ThreadCtx ctx{};
    for (auto offset = ctx.global_warp_idx * warpsize; offset < num_query;
         offset += ctx.num_warps * warpsize) {
      auto idx = offset + ctx.lane_id;
      if (idx < num_query && !find_mask[idx]) {
        uint32_t base_missing_idx = atomicAdd(n_missing, 1);
        missing_keys[base_missing_idx] = keys[idx];
      }
    }
  }
void TestCache(Cache<int32_t ,uint32_t>& cache, uint32_t num_elem_per_value){
constexpr int warpsize = 32;
std::unordered_set<uint32_t> in_cache;
......@@ -114,20 +132,20 @@ namespace gpucache{
dim3 block(256);
dim3 grid((n_keys + block.x - 1)/block.x);
// CollectMissingKeys<uint32_t><<<grid,block>>>(n_keys,d_keys,d_find_mask,d_n_missing,d_missing_keys);
CollectMissingKeys<int32_t><<<grid,block>>>(n_keys,d_keys,d_find_mask,d_n_missing,d_missing_keys);
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemcpy(n_missing,d_n_missing, sizeof(uint32_t),cudaMemcpyDefault));
CUDA_CHECK(cudaMemcpy(missing_keys,d_missing_keys,keys_size,cudaMemcpyDefault));
ASSERT_EQ(expect_n_missing, *n_missing) << "expect_n_missing is " << expect_n_missing << " n_missing is " << *n_missing;
std::unordered_set<int32_t> missing_keys_set(missing_keys, missing_keys + *n_missing);
ASSERT_EQ(missing_keys_set,expect_missing_keys_set);
    // check get value
CUDA_CHECK(cudaMemcpy(values,d_values,values_size,cudaMemcpyDefault));
for (size_t i = 0; i < n_keys ; i += 1) {
if(find_mask[i]){
ASSERT_EQ(keys[i] + 123,values[i * num_elem_per_value]) << "key[" << i << "] = " << keys[i] << " doesn't get correct value should be " << keys[i] + 123 << " get " << values[i * num_elem_per_value];
}
}
......@@ -187,30 +205,68 @@ namespace gpucache{
cfg.deviceId = 0;
cfg.dim = 8;
lrucache::LRUCache<int32_t,uint32_t> cache(cfg);
TestCache(cache,cfg.dim);
}
TEST(GPUCACHE,FIFOCACHE){
CacheConfig cfg{};
cfg.strategy = CacheConfig::CacheEvictStrategy::FIFO;
cfg.valueSize = 32;
cfg.capacity = 4096 * 2;
cfg.keySize = 4;
cfg.maxQueryNum = 2048;
cfg.deviceId = 0;
cfg.dim = 8;
fifocache::FIFOCache<int32_t,uint32_t> cache(cfg);
TestCache(cache,cfg.dim);
}
TEST(GPUCACHE, FIFOCACHEWRAPPER){
CacheConfig cfg{CacheConfig::CacheEvictStrategy::FIFO,8192,4,32,2048,0,8};
auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
auto cache = fifocache::NewFIFOCache(t,cfg);
auto keys = torch::arange(0,5,torch::dtype(torch::kInt32)).to(torch::kCUDA, 0);
auto [values, find_mask] = cache->Get(5,keys);
torch::Tensor put_values = torch::reshape(torch::arange(0, 5 * 8,torch::dtype(torch::kInt32)),{5,8}).to(torch::kCUDA, 0);
std::cout << "put_values: " << put_values << std::endl;
cache->Put(5,keys,put_values);
auto result = cache->Get(5,keys);
values = result.first;
find_mask = result.second;
std::cout << " values: " << values << "find_mask: " << find_mask << std::endl;
// auto fifoc = reinterpret_cast<fifocache::FIFOCache<int32_t,int32_t>*>(cache->fifo_cache);
// std::cout <<"test keysize" << fifoc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize());
}
TEST(GPUCACHE, LRUCACHEWRAPPER){
CacheConfig cfg{CacheConfig::CacheEvictStrategy::LRU,8192,4,32,2048,0,8};
auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
auto cache = lrucache::NewLRUCache(t,cfg);
auto keys = torch::arange(0,5,torch::dtype(torch::kInt32)).to(torch::kCUDA, 0);
auto [values, find_mask] = cache->Get(5,keys);
torch::Tensor put_values = torch::reshape(torch::arange(0, 5 * 8,torch::dtype(torch::kInt32)),{5,8}).to(torch::kCUDA, 0);
std::cout << "put_values: " << put_values << std::endl;
cache->Put(5,keys,put_values);
auto result = cache->Get(5,keys);
values = result.first;
find_mask = result.second;
std::cout << " values: " << values << "find_mask: " << find_mask << std::endl;
// auto lruc = reinterpret_cast<lrucache::LRUCache<int32_t,int32_t>*>(cache->lru_cache);
// std::cout <<"test keysize" << lruc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize());
}
}
......