编程语言
首页 > 编程语言 > c++runtime

c++runtime

作者:互联网

#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

namespace dl {

// A toy string-keyed op runtime: ops are free functions taking a Params bag,
// registered by name in _func_d and invoked by name via EXEC; tensors are
// stored by name in the global _tensor_d registry.

// Named parameter bag passed to every op. Each map holds parameters of one
// C++ type, keyed by parameter name (e.g. s["name"], s["dtype"], vi["shape"]).
struct Params {
    std::unordered_map<std::string, std::string> s;          // string params
    std::unordered_map<std::string, int> i;                  // int params
    std::unordered_map<std::string, float> f;                // float params
    std::unordered_map<std::string, std::vector<int>> vi;    // int-list params
    std::unordered_map<std::string, std::vector<float>> vf;  // float-list params
};

// A named tensor record stored in the registry.
//   dtype - TYPE(T), i.e. typeid(T).name() of the element type. The text is
//           implementation-defined, so it is only meaningful within one build.
//   data  - buffer allocated with new char[] by empty_cpu; ~Tensor does NOT
//           free it (empty_cpu frees it when the name is re-created).
// Members have default initializers so copying a Tensor never reads an
// indeterminate value (`device` is never assigned by the code in this file).
struct Tensor {
    std::string dtype;
    std::vector<int> shape;
    void *data = nullptr;
    int numel = 0;
    int device = 0;
};

// Signature every registered op must have.
using OpFn = void (*)(Params &);

// Op registry: name -> op, stored type-erased as void*. Round-tripping a
// function pointer through void* is conditionally-supported in C++ (accepted
// by GCC, Clang and MSVC).
std::unordered_map<std::string, void *> _func_d;
// Tensor registry: name -> Tensor.
std::unordered_map<std::string, Tensor> _tensor_d;

// Register FUNC under NAME. The static_cast supplies a target function type,
// so an overloaded name (e.g. `empty`, which also names std::empty in C++17
// under `using namespace std`) resolves to the op instead of failing to compile.
#define REGISTER(FUNC, NAME) \
    ::dl::_func_d[NAME] = reinterpret_cast<void *>(static_cast<dl::OpFn>(FUNC))
// Invoke the op registered under NAME with PARAMS. at() throws
// std::out_of_range for an unregistered NAME instead of default-inserting a
// null pointer and calling through it.
#define EXEC(NAME, PARAMS) \
    (*reinterpret_cast<dl::OpFn>(::dl::_func_d.at(NAME)))(PARAMS)
// Implementation-defined name string used as the dtype tag for type DATATYPE.
#define TYPE(DATATYPE) typeid(DATATYPE).name()
#define GLOBAL_TENSOR _tensor_d
// If TYPENAME equals TYPE(T), call FUNC<T>(...); `scalar_t` is an alias for T
// inside the call arguments. Expands to a bare `if` with no else, so an
// unmatched TYPENAME does nothing; callers validate the dtype beforehand.
#define HANDLE_DTYPE(T, TYPENAME, FUNC, ...) \
    if (TYPE(T) == TYPENAME) {               \
        using scalar_t = T;                  \
        FUNC<scalar_t>(__VA_ARGS__);         \
    }
// Dispatch TYPENAME over two / three candidate element types.
#define HANDLE_DTYPE2(T1, T2, ...) \
    HANDLE_DTYPE(T1, __VA_ARGS__)  \
    HANDLE_DTYPE(T2, __VA_ARGS__)
#define HANDLE_DTYPE3(T1, T2, T3, ...) \
    HANDLE_DTYPE(T1, __VA_ARGS__)      \
    HANDLE_DTYPE(T2, __VA_ARGS__)      \
    HANDLE_DTYPE(T3, __VA_ARGS__)

// Throws std::invalid_argument unless `dtype` is TYPE(int) or TYPE(float),
// the element types the ops below dispatch on.
static void check_supported_dtype(const std::string &dtype) {
    if (dtype != TYPE(int) && dtype != TYPE(float)) {
        throw std::invalid_argument("dl: unsupported dtype '" + dtype + "'");
    }
}

// Allocate an uninitialized tensor of element type scalar_t with the given
// shape and register it under `name`. Any tensor already registered under
// that name is replaced, and its buffer is freed.
// Throws std::invalid_argument on a negative dimension, and
// std::overflow_error if the element count (int) or byte count (size_t)
// would overflow.
template <typename scalar_t>
void empty_cpu(std::string name, std::vector<int> shape) {
    int numel = 1;  // an empty shape denotes a single-element (scalar) tensor
    for (int d : shape) {
        if (d < 0) {
            throw std::invalid_argument("dl::empty: negative dimension in shape");
        }
        // Guard the int product: signed overflow is undefined behavior.
        if (d != 0 && numel > std::numeric_limits<int>::max() / d) {
            throw std::overflow_error("dl::empty: element count overflows int");
        }
        numel *= d;
    }
    const std::size_t count = static_cast<std::size_t>(numel);
    if (count > std::numeric_limits<std::size_t>::max() / sizeof(scalar_t)) {
        throw std::overflow_error("dl::empty: byte count overflows size_t");
    }

    // Keep the new buffer in a unique_ptr until the registry holds it, so an
    // exception from the map insertion cannot leak it.
    std::unique_ptr<char[]> buffer(new char[count * sizeof(scalar_t)]);

    Tensor t;
    t.dtype = TYPE(scalar_t);
    t.shape = std::move(shape);
    t.numel = numel;
    t.data = buffer.get();

    auto it = GLOBAL_TENSOR.find(name);
    if (it != GLOBAL_TENSOR.end()) {
        // Re-creating an existing name: free the old buffer instead of leaking
        // it. Assumes registry buffers come from empty_cpu (or are nullptr).
        delete[] static_cast<char *>(it->second.data);
        it->second = std::move(t);
    } else {
        GLOBAL_TENSOR.emplace(std::move(name), std::move(t));
    }
    buffer.release();  // ownership now rests with the registry entry
}

// Op "empty": create (or re-create) tensor p.s["name"] with element type
// p.s["dtype"] (TYPE(int) or TYPE(float)) and shape p.vi["shape"]; contents
// are left uninitialized.
// Throws std::invalid_argument for an unsupported dtype or a negative
// dimension.
void empty(Params &p) {
    const std::string dtype = p.s["dtype"];
    check_supported_dtype(dtype);
    HANDLE_DTYPE2(int, float, dtype, empty_cpu, p.s["name"], p.vi["shape"])
}

// Fill data[0..numel) with 0, 1, 2, ... converted to scalar_t.
template <typename scalar_t>
void linspace_cpu_(scalar_t *data, int numel) {
    for (int i = 0; i < numel; i++) data[i] = static_cast<scalar_t>(i);
}

// Op "linspace_": overwrite tensor p.s["name"] in place with 0, 1, 2, ...
// Throws std::out_of_range if no tensor of that name is registered.
void linspace_(Params &p) {
    const std::string &name = p.s["name"];
    auto it = GLOBAL_TENSOR.find(name);
    if (it == GLOBAL_TENSOR.end()) {
        throw std::out_of_range("dl::linspace_: no tensor named '" + name + "'");
    }
    Tensor &self = it->second;  // reference: no copy of dtype/shape
    check_supported_dtype(self.dtype);
    HANDLE_DTYPE2(int, float, self.dtype, linspace_cpu_,
                  static_cast<scalar_t *>(self.data), self.numel)
}

} // namespace dl

int main() {
    using namespace std;
    using namespace dl;

    REGISTER(empty, "empty");
    REGISTER(linspace_, "linspace_");

    Params p;
    p.s["name"] = "Tensor0";
    p.s["dtype"] = TYPE(float);
    p.vi["shape"] = std::vector<int>{2, 2};
    EXEC("empty", p);

    Params p2;
    p2.s["name"] = "Tensor0";
    EXEC("linspace_", p2);

    Tensor &t = GLOBAL_TENSOR["Tensor0"];
    float *ptr = (float *)t.data;
    for (int i = 0; i < t.numel; i++) {
        cout << ptr[i] << endl;
    }
}

标签:std,__,HANDLE,int,void,c++,numel,runtime
来源: https://www.cnblogs.com/xytpai/p/15511436.html