Commit 8f24d345 by tianxing wang

finish lfucache

parent f081fb39
...@@ -25,7 +25,8 @@ message(STATUS "torch_include_dirs" ${TORCH_INCLUDE_DIRS}) ...@@ -25,7 +25,8 @@ message(STATUS "torch_include_dirs" ${TORCH_INCLUDE_DIRS})
find_package (Python3 COMPONENTS Interpreter Development REQUIRED) find_package (Python3 COMPONENTS Interpreter Development REQUIRED)
set(cache_lib_name gpucache) set(cache_lib_name gpucache)
add_library(${cache_lib_name} SHARED ${SOURCE_FILES}) add_library(${cache_lib_name} SHARED ${SOURCE_FILES}
src/lfu_cache.h)
set_target_properties(${cache_lib_name} PROPERTIES set_target_properties(${cache_lib_name} PROPERTIES
CUDA_SEPARABLE_COMPILATION ON CUDA_SEPARABLE_COMPILATION ON
CUDA_ARCHITECTURES "86" CUDA_ARCHITECTURES "86"
......
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream>
/* /*
* different CUDA devices provide different atomic functions, so we need to adapt them * different CUDA devices provide different atomic functions, so we need to adapt them
......
...@@ -49,7 +49,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ ...@@ -49,7 +49,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("NewLRUCache", &gpucache::lrucache::NewLRUCache, "create a lru cache",py::return_value_policy::reference); m.def("NewLRUCache", &gpucache::lrucache::NewLRUCache, "create a lru cache",py::return_value_policy::reference);
/*-----------------------------------------------------------------------------------------------------------------------------------*/ /*-----------------------------------------------------------------------------------------------------------------------------------*/
/* lrucache */ /* fifocache */
/*-----------------------------------------------------------------------------------------------------------------------------------*/ /*-----------------------------------------------------------------------------------------------------------------------------------*/
py::class_<gpucache::fifocache::FIFOCacheWrapper> fifo_cache(m, "FIFOCache"); py::class_<gpucache::fifocache::FIFOCacheWrapper> fifo_cache(m, "FIFOCache");
......
...@@ -180,6 +180,7 @@ namespace gpucache { ...@@ -180,6 +180,7 @@ namespace gpucache {
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType))); CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize)); CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t))); CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex))); CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX); dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX); dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
......
#pragma once
#include <torch/extension.h>
#include "cache.h"
namespace gpucache {
namespace lfucache {
// Tensor-facing wrapper around an LFU cache instance.
// The wrapped cache is stored behind an untyped pointer (lfu_cache) together with the
// key/value dtypes and a copy of the CacheConfig.
// NOTE(review): presumably the implementation casts lfu_cache to the concrete cache type
// selected by dtype/kdtype — confirm against the implementation file.
class LFUCacheWrapper {
public:
// Builds the wrapper from a tensor and a cache configuration.
// NOTE(review): `t` looks like it is only used to select the value dtype — confirm.
LFUCacheWrapper(at::Tensor t, CacheConfig cfg);
// NOTE(review): expected to release the object behind lfu_cache — confirm in the implementation.
~LFUCacheWrapper();
// Queries the cache with `num_query` keys taken from `queries`.
// Returns a pair of tensors; NOTE(review): callers bind it as (values, find_mask) —
// confirm the exact shape/dtype of each element in the implementation.
std::pair<torch::Tensor, torch::Tensor> Get(uint32_t num_query, const torch::Tensor queries);
// Inserts `num_query` key/value entries taken from `keys` and `values`.
void Put(uint32_t num_query, const torch::Tensor keys, const torch::Tensor values);
// Accessors below take no arguments and return configuration-style scalars.
// NOTE(review): presumably they forward the values held in cache_cfg / the wrapped cache — confirm.
CacheConfig::CacheEvictStrategy Strategy();
uint64_t Capacity();
uint32_t KeySize();
uint32_t ValueSize();
uint32_t MaxQueryNum();
uint64_t DeviceId();
uint32_t Dim();
// NOTE(review): presumably resets the wrapped cache to an empty state — confirm.
void Clear();
private:
// Type-erased handle to the wrapped cache object.
void *lfu_cache;
c10::ScalarType dtype; // value dtype
c10::ScalarType kdtype; // key dtype
bool key_is_int32; // only support int32 and int64
// Configuration this wrapper was created with.
CacheConfig cache_cfg;
};
// Factory: returns a uniquely-owned LFUCacheWrapper created from tensor `t` and config `cfg`.
// NOTE(review): presumably forwards both arguments to the LFUCacheWrapper constructor — confirm.
std::unique_ptr<LFUCacheWrapper> NewLFUCache(at::Tensor t, CacheConfig cfg);
} // namespace lfucache
} // namespace gpucache
\ No newline at end of file
...@@ -165,7 +165,7 @@ namespace gpucache { ...@@ -165,7 +165,7 @@ namespace gpucache {
uint32_t keySize; uint32_t keySize;
uint32_t valueSize; uint32_t valueSize;
uint32_t numElemPerValue; // embedding dim uint32_t numElemPerValue; // embedding dim
int8_t device_id; int8_t deviceId;
uint32_t dim; uint32_t dim;
// store missing keys and indices for Evict // store missing keys and indices for Evict
...@@ -178,17 +178,18 @@ namespace gpucache { ...@@ -178,17 +178,18 @@ namespace gpucache {
template<typename KeyType, typename ElemType> template<typename KeyType, typename ElemType>
LRUCache<KeyType, ElemType>::LRUCache(const CacheConfig &cfg) LRUCache<KeyType, ElemType>::LRUCache(const CacheConfig &cfg)
: keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity), : keySize(cfg.keySize), valueSize(cfg.valueSize), capacity(cfg.capacity),
maxQueryNum(cfg.maxQueryNum), device_id(cfg.deviceId), dim(cfg.dim) { maxQueryNum(cfg.maxQueryNum), deviceId(cfg.deviceId), dim(cfg.dim) {
// std::cout << "LRUCache constructor" << std::endl; // std::cout << "LRUCache constructor" << std::endl;
numElemPerValue = valueSize / sizeof(ElemType); numElemPerValue = valueSize / sizeof(ElemType);
nbucket = (capacity + warpsize - 1) / warpsize; nbucket = (capacity + warpsize - 1) / warpsize;
printf("LRUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, " printf("LRUCache: keySize: %lu, valueSize: %u, dim: %u, capacity: %lu, "
"maxQueryNum: %u, deviceId: %u\n", "maxQueryNum: %u, deviceId: %u\n",
sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, device_id); sizeof(KeyType), valueSize, dim, capacity, maxQueryNum, deviceId);
CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType))); CUDA_CHECK(cudaMalloc((void **) &keys, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize)); CUDA_CHECK(cudaMalloc((void **) &values, capacity * valueSize));
CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t))); CUDA_CHECK(cudaMalloc((void **) &timestamps, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex))); CUDA_CHECK(cudaMalloc((void **) &bucketMutexes, nbucket * sizeof(WarpMutex)));
dim3 block(defaultBlockX); dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX); dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
...@@ -200,7 +201,7 @@ namespace gpucache { ...@@ -200,7 +201,7 @@ namespace gpucache {
template<typename KeyType, typename ElemType> template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Clear() { void LRUCache<KeyType, ElemType>::Clear() {
CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType))); // CUDA_CHECK(cudaMemset(keys, 0, capacity * sizeof(KeyType)));
CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t))); CUDA_CHECK(cudaMemset(timestamps, 0, capacity * sizeof(uint8_t)));
dim3 block(defaultBlockX); dim3 block(defaultBlockX);
dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX); dim3 grid((nbucket + defaultBlockX - 1) / defaultBlockX);
...@@ -235,7 +236,7 @@ namespace gpucache { ...@@ -235,7 +236,7 @@ namespace gpucache {
uint32_t LRUCache<KeyType, ElemType>::MaxQueryNum() { return maxQueryNum; } uint32_t LRUCache<KeyType, ElemType>::MaxQueryNum() { return maxQueryNum; }
template<typename KeyType, typename ElemType> template<typename KeyType, typename ElemType>
int8_t LRUCache<KeyType, ElemType>::DeviceId() { return device_id; } int8_t LRUCache<KeyType, ElemType>::DeviceId() { return deviceId; }
template<typename KeyType, typename ElemType> template<typename KeyType, typename ElemType>
uint32_t LRUCache<KeyType, ElemType>::Dim() { return dim; } uint32_t LRUCache<KeyType, ElemType>::Dim() { return dim; }
...@@ -282,8 +283,8 @@ namespace gpucache { ...@@ -282,8 +283,8 @@ namespace gpucache {
} }
__syncwarp(); __syncwarp();
// 32 threads compare it own slot with key // 32 threads compare it own slot with key
// if find parallel write to result // if find write to result
for (uint32_t i = 0; i < n_query; i++) { for (uint32_t i = 0; i < n_query; i++) {
uint32_t idx = offset + i; uint32_t idx = offset + i;
KeyType key = blockQueryKeys[ctx.block_warp_idx][i]; KeyType key = blockQueryKeys[ctx.block_warp_idx][i];
...@@ -434,7 +435,7 @@ namespace gpucache { ...@@ -434,7 +435,7 @@ namespace gpucache {
} }
} }
// TODO async CudaMemcpy // TODO async CudaMemcpy
template<typename KeyType, typename ElemType> template<typename KeyType, typename ElemType>
void LRUCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query, void LRUCache<KeyType, ElemType>::Put(cudaStream_t stream, uint32_t num_query,
KeyType *put_keys, KeyType *put_keys,
......
...@@ -39,18 +39,18 @@ namespace gpucache { ...@@ -39,18 +39,18 @@ namespace gpucache {
} }
} }
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) { // __global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes) {
uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; // uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
// printf("thread %u CUDA_CHECK lock\n",global_thread_idx); // // printf("thread %u CUDA_CHECK lock\n",global_thread_idx);
if (global_thread_idx < n_bucket) { // if (global_thread_idx < n_bucket) {
auto mutex = // auto mutex =
reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx; // reinterpret_cast<WarpMutex *>(bucketMutexes) + global_thread_idx;
if (mutex->flag != 0u && mutex->flag != 1u) { // if (mutex->flag != 0u && mutex->flag != 1u) {
printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx, // printf("bucket id %u not equal 0 or 1, is %u\n", global_thread_idx,
mutex->flag); // mutex->flag);
} // }
} // }
} // }
} }
......
...@@ -45,7 +45,7 @@ namespace gpucache{ ...@@ -45,7 +45,7 @@ namespace gpucache{
__global__ void initLocks(uint32_t n_bucket, void *bucketMutexes); __global__ void initLocks(uint32_t n_bucket, void *bucketMutexes);
__global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes); // __global__ void checkLocks(uint32_t n_bucket, void *bucketMutexes);
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
"source": [ "source": [
"# args: strtegy, capacity, keySize(only support int32 and int64), valueSize(dim * sizeof(elem), elem is decided by the passing tensor dtype of NewLRUCache), deviceId, dim\n", "# args: strtegy, capacity, keySize(only support int32 and int64), valueSize(dim * sizeof(elem), elem is decided by the passing tensor dtype of NewLRUCache), deviceId, dim\n",
"# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LRU,65536,4,128,4096,0,32)\n", "# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LRU,65536,4,128,4096,0,32)\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.FIFO,65536,4,128,4096,0,32)" "# cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.FIFO,65536,4,128,4096,0,32)\n",
"cfg = libgpucache.CacheConfig(libgpucache.CacheConfig.LFU,65536,4,128,4096,0,32)"
] ]
}, },
{ {
...@@ -29,7 +30,7 @@ ...@@ -29,7 +30,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(<CacheEvictStrategy.FIFO: 0>, 65536, 4, 128, 4096, 0, 32)" "(<CacheEvictStrategy.LFU: 2>, 65536, 4, 128, 4096, 0, 32)"
] ]
}, },
"execution_count": 3, "execution_count": 3,
...@@ -132,7 +133,7 @@ ...@@ -132,7 +133,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -151,7 +152,7 @@ ...@@ -151,7 +152,7 @@
" torch.Size([3, 32]))" " torch.Size([3, 32]))"
] ]
}, },
"execution_count": 9, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include "../src/lru_cache.cu" #include "../src/lru_cache.cu"
#include "../src/fifo_cache.h" #include "../src/fifo_cache.h"
#include "../src/fifo_cache.cu" #include "../src/fifo_cache.cu"
#include "../src/lfu_cache.h"
#include "../src/lfu_cache.cu"
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -226,6 +228,22 @@ namespace gpucache{ ...@@ -226,6 +228,22 @@ namespace gpucache{
} }
// Exercises a directly constructed LFUCache through the shared TestCache routine.
TEST(GPUCACHE,LFUCACHE){
    CacheConfig config{};
    config.strategy = CacheConfig::CacheEvictStrategy::LFU;

    // Key layout: 4-byte keys, matching the int32_t key type used below.
    config.keySize = 4;

    // Value layout: 8 elements per value, 32 bytes in total (8 * sizeof(uint32_t)).
    config.dim = 8;
    config.valueSize = 32;

    // Sizing and placement.
    config.capacity = 8192;
    config.maxQueryNum = 2048;
    config.deviceId = 0;

    lfucache::LFUCache<int32_t,uint32_t> lfu(config);
    TestCache(lfu, config.dim);
}
TEST(GPUCACHE, FIFOCACHEWRAPPER){ TEST(GPUCACHE, FIFOCACHEWRAPPER){
CacheConfig cfg{CacheConfig::CacheEvictStrategy::FIFO,8192,4,32,2048,0,8}; CacheConfig cfg{CacheConfig::CacheEvictStrategy::FIFO,8192,4,32,2048,0,8};
auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0)); auto t = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
...@@ -243,8 +261,6 @@ namespace gpucache{ ...@@ -243,8 +261,6 @@ namespace gpucache{
find_mask = result.second; find_mask = result.second;
std::cout << " values: " << values << "find_mask: " << find_mask << std::endl; std::cout << " values: " << values << "find_mask: " << find_mask << std::endl;
// auto fifoc = reinterpret_cast<fifocache::FIFOCache<int32_t,int32_t>*>(cache->fifo_cache);
// std::cout <<"test keysize" << fifoc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize()); CUDA_CHECK(cudaDeviceSynchronize());
} }
...@@ -269,5 +285,25 @@ namespace gpucache{ ...@@ -269,5 +285,25 @@ namespace gpucache{
// std::cout <<"test keysize" << lruc->KeySize() << std::endl; // std::cout <<"test keysize" << lruc->KeySize() << std::endl;
CUDA_CHECK(cudaDeviceSynchronize()); CUDA_CHECK(cudaDeviceSynchronize());
} }
// Smoke test for the tensor-facing LFU wrapper: query, insert, then query again and print the result.
TEST(GPUCACHE, LFUCACHEWRAPPER){
    constexpr int kNumKeys = 5;  // number of keys queried / inserted
    constexpr int kDim = 8;      // elements per value, matches cfg.dim

    CacheConfig cfg{CacheConfig::CacheEvictStrategy::LFU,8192,4,32,2048,0,8};
    // Single-element int32 CUDA tensor handed to the factory.
    auto dtype_probe = torch::empty({1},torch::dtype(torch::kInt32).device(torch::kCUDA,0));
    auto cache = lfucache::NewLFUCache(dtype_probe,cfg);

    // Keys 0..4 on device 0.
    auto keys = torch::arange(0,kNumKeys,torch::dtype(torch::kInt32)).to(torch::kCUDA, 0);

    // First lookup happens before anything has been inserted.
    auto lookup = cache->Get(kNumKeys,keys);

    // Values 0..39 laid out as a 5 x 8 int32 tensor on device 0.
    torch::Tensor put_values =
        torch::reshape(torch::arange(0, kNumKeys * kDim,torch::dtype(torch::kInt32)),{kNumKeys,kDim})
            .to(torch::kCUDA, 0);
    std::cout << "put_values: " << put_values << std::endl;
    cache->Put(kNumKeys,keys,put_values);

    // Second lookup, after the insert; its result is what gets printed.
    lookup = cache->Get(kNumKeys,keys);
    std::cout << " values: " << lookup.first << "find_mask: " << lookup.second << std::endl;

    CUDA_CHECK(cudaDeviceSynchronize());
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment