其他分享
首页 > 其他分享 > 3dlut 生成和apply

3dlut 生成和apply

作者:互联网

在这里插入图片描述

def gen_3dlut(src, target, lut_dim, use_diff):
    """
    Build a 3D LUT of colour differences (target - src) from paired samples.

    Each sample is assigned to the LUT cell that contains its ``src`` colour,
    and its difference ``target - src`` is scattered onto all 8 lattice
    vertices of that cell. A vertex hit by several samples keeps, per channel,
    the largest of their differences. A vertex hit by no sample is 0, so adding
    an identity LUT afterwards leaves those colours unchanged.

    :param src: n*3 array, float, 0-1
    :param target: n*3 array, float, 0-1 (must broadcast against ``src``)
    :param lut_dim: number of lattice points per axis, >= 2
    :param use_diff: unused by the current implementation; kept for
        compatibility with existing callers
    :return: (lut_dim**3, 3) float64 array of differences, where vertex
        (r, g, b) is stored at row ``r + g*lut_dim + b*lut_dim**2``;
        -1 (after printing "error input !") if ``src`` is not n*3
    """
    src = np.asarray(src, dtype=np.float64)
    if src.ndim != 2 or src.shape[1] != 3:
        print("error input !")
        return -1
    diff = np.broadcast_to(np.asarray(target, dtype=np.float64) - src, src.shape)

    # Start from -inf rather than 0. With a zero start, every negative
    # difference would be clipped to 0 by the max-reduction below.
    lut = np.full([lut_dim * lut_dim * lut_dim, 3], -np.inf, dtype=np.float64)

    # Slightly enlarged bin, so that a value of exactly 1.0 still falls into the
    # last cell (index lut_dim - 2) instead of past the end of the lattice.
    bin_size = 1.000000001 / (lut_dim - 1)
    ind000 = np.floor(src / bin_size).astype(np.int64)
    # Keep every "+1" corner inside the lattice for inputs outside [0, 1].
    ind000 = np.clip(ind000, 0, lut_dim - 2)

    for dr, dg, db in np.ndindex(2, 2, 2):
        pos = ((ind000[:, 0] + dr)
               + (ind000[:, 1] + dg) * lut_dim
               + (ind000[:, 2] + db) * lut_dim * lut_dim)
        # ufunc.at is unbuffered, so when several samples hit the same vertex
        # all of them take part in the max. Fancy-index assignment
        # (lut[pos] = ...) would keep only the last sample's value.
        np.maximum.at(lut, pos, diff)

    # Vertices that no sample touched carry no correction.
    lut[np.isneginf(lut)] = 0.0
    return lut

def cal_invdis(src, src_ind, method="inv", std=1):
    """
    Compute a weight that decreases with the Euclidean distance between
    ``src`` and ``src_ind``, where the distance is taken over the last axis.

    1. method = "gauss":
        distance = sqrt(x*x + y*y + z*z)
        weight = 1 / (std * sqrt(2*pi)) * exp(-(distance**2) / (2*std*std))
    2. method = "inv":
        distance = sqrt(x*x + y*y + z*z)
        weight = 1 / (distance + eps), with eps = 1e-7

    :param src: (..., 3) array of points
    :param src_ind: array of reference points, broadcastable against ``src``
    :param method: "inv" (default) or "gauss"
    :param std: standard deviation of the Gaussian (used only by "gauss")
    :return: array of weights with the broadcast shape minus the last axis
    :raises ValueError: if ``method`` is neither "inv" nor "gauss"
    """
    distance2 = np.sum((np.asarray(src) - np.asarray(src_ind)) ** 2, axis=-1)
    if method == "gauss":
        return 1 / (std * np.sqrt(2 * np.pi)) * np.exp(-distance2 / (2 * std * std))
    if method == "inv":
        # eps avoids division by zero when a point sits exactly on the vertex.
        return 1 / (np.sqrt(distance2) + 0.0000001)
    raise ValueError(f"unknown method {method!r}; expected 'inv' or 'gauss'")

立体（三维）平滑操作：对 LUT 的每个格点取其三维邻域的均值（盒式滤波）。另一种更好的方法应该是分层平滑。

def filter_lut(lut, lut_dim, kernel_len):
    """
    Smooth a 3D LUT with a kernel_len**3 box (mean) filter.

    Every lattice vertex is replaced by the mean of the kernel_len**3 vertices
    around it, with offsets ``range(kernel_len) - kernel_len // 2`` along each
    axis. Neighbours that fall outside the lattice are replaced by the nearest
    border vertex (clamped indexing / edge replication). Border vertices are
    therefore averaged over the same number of samples as interior ones.

    :param lut: array holding lut_dim**3 entries of C channels, e.g. (lut_dim**3, 3)
    :param lut_dim: number of lattice points per axis
    :param kernel_len: box size per axis, >= 1 (1 returns the values unchanged)
    :return: (lut_dim**3, C) smoothed LUT
    :raises ValueError: if kernel_len < 1
    """
    if kernel_len < 1:
        raise ValueError(f"kernel_len must be >= 1, got {kernel_len}")
    grid = np.asarray(lut).reshape(lut_dim, lut_dim, lut_dim, -1)
    out_dtype = grid.dtype if np.issubdtype(grid.dtype, np.floating) else np.float64

    half = kernel_len // 2
    # Pad each spatial axis so that offsets -half .. kernel_len-1-half are always
    # valid. mode="edge" repeats the border vertex, which is the same as clamping
    # the neighbour index to [0, lut_dim - 1].
    pad = (half, kernel_len - 1 - half)
    padded = np.pad(grid, [pad, pad, pad, (0, 0)], mode="edge")

    acc = np.zeros(grid.shape, dtype=np.float64)
    for di, dj, dk in np.ndindex(kernel_len, kernel_len, kernel_len):
        acc += padded[di:di + lut_dim, dj:dj + lut_dim, dk:dk + lut_dim]
    smoothed = acc / kernel_len ** 3
    return smoothed.astype(out_dtype, copy=False).reshape(-1, grid.shape[-1])

是否使用 HSV 颜色空间：

# Driver excerpt (the enclosing function header is not shown in this post).
# It builds a difference LUT from an image pair with gen_3dlut, then turns it
# into an absolute LUT by adding an identity LUT.
# NOTE(review): im1, im2, colour and get_id_lut come from outside this excerpt.
# im1/im2 are presumably (n, 3) float arrays in [0, 1] -- confirm.
# NOTE(review): this first line is at column 0 while the lines below are
# indented, so the excerpt does not run as pasted; it belongs inside a function.
use_hsv = 1
    # Optionally build the LUT from HSV values instead of RGB values.
    if use_hsv:
        hsv1 = colour.RGB_to_HSV(im1)
        hsv2 = colour.RGB_to_HSV(im2)
        print(hsv1.max(), hsv2.max(), hsv1.min(), hsv2.min())
        im1 = hsv1
        im2 = hsv2
    # Debug switch: replace the inputs with 8 copies of one fixed source/target
    # pixel pair (given in 0-255, scaled to 0-1).
    # NOTE(review): if both switches are on, the HSV conversion above is
    # discarded, because im1/im2 are overwritten here with RGB values.
    test_single_pixel = 0
    if test_single_pixel:
        im1 = np.array([202, 209, 216]*8).reshape(-1, 3) / 255
        im2 = np.array([191, 211, 235]*8).reshape(-1, 3) / 255
        print(im1, im2)
    # 65 lattice points per axis.
    lut_dim = 65
    # gen_3dlut returns per-vertex differences (target - src), not absolute colours.
    lut = gen_3dlut(im1, im2, lut_dim, use_diff=1)
    if use_hsv:
        # NOTE(review): this passes a LUT of HSV *differences* through HSV_to_RGB,
        # i.e. it treats the deltas as absolute HSV colours -- confirm this is intended.
        lut = colour.HSV_to_RGB(lut)
    # Optional smoothing of the LUT (disabled).
    #lut = filter_lut(lut, lut_dim, 3)

    if test_single_pixel:
        # Lattice index of the first test pixel's R, G, B. reshape(-1) flattens all
        # 8 pixels, but only a[0], a[1], a[2] are used below.
        a = np.floor(im1 * (lut_dim - 1)).reshape(-1).astype(np.int32)
        print('a', a)
        # gen_3dlut stores vertex (r, g, b) at row r + g*lut_dim + b*lut_dim**2,
        # so after this reshape the index order is [b, g, r].
        lut3d = lut.reshape(lut_dim, lut_dim, lut_dim, 3)
        print(lut3d[a[2], a[1], a[0],], lut3d[a[2] + 1, a[1], a[0],], lut3d[a[2], a[1] + 1, a[0],],
              lut3d[a[2], a[1], a[0] + 1,], lut3d[a[2] + 1, a[1] + 1, a[0] + 1,])
    # Add an identity LUT so the result maps colours to absolute values, not deltas.
    # NOTE(review): get_id_lut is not shown. It presumably returns a (lut_dim**3, 3)
    # identity LUT in the same r-fastest order -- confirm.
    lut0 = get_id_lut(lut_dim)
    lut = lut + lut0
**以上只是对相邻格点进行处理来生成 LUT，利用 K 近邻算法可以得到更好的结果。**

在这里插入代码片

def img_to_lut(input_image, target_image, lut_dim, radius=None):
    """
    Fit a 3D LUT that maps ``input_image`` colours towards ``target_image``.

    For every pixel, the colour difference (target - input) is recorded at the
    pixel's input colour. A radius-neighbours regression with inverse-distance
    weights then predicts the difference at every point of an equally spaced
    lut_dim**3 RGB lattice. Lattice points with no sample inside ``radius``
    get difference 0 (the regressor returns NaN for them). The predicted
    difference is added to the lattice colour, and the result is clipped
    below at 0.

    Reference: https://github.com/ajcommercial/AI_color_grade_lut

    :param input_image: (H, W, 3) float array, presumably in [0, 1]
    :param target_image: array with the same shape as ``input_image``
    :param lut_dim: number of lattice points per axis, >= 2
    :param radius: neighbourhood radius in colour units; defaults to 10 / lut_dim
    :return: (lut_dim**3, 3) array; lattice colour (i, j, k) / (lut_dim - 1) is
        at row ``i + j*lut_dim + k*lut_dim**2`` (red varies fastest)
    """
    inp = np.asarray(input_image)
    tgt = np.asarray(target_image)

    # One row per pixel, in the same row-major pixel order for both images.
    colors = inp.reshape(-1, inp.shape[-1])
    deltas = (tgt[..., :3] - inp[..., :3]).reshape(-1, 3)

    if radius is None:
        radius = 10 / lut_dim
    regressor = neighbors.RadiusNeighborsRegressor(radius=radius, weights='distance')

    # Equally spaced lattice in [0, 1]^3. With 'ij' indexing and a C-order ravel,
    # red varies fastest, then green, then blue.
    axis = np.arange(lut_dim) / (lut_dim - 1)
    blue, green, red = np.meshgrid(axis, axis, axis, indexing='ij')
    lut = np.stack([red.ravel(), green.ravel(), blue.ravel()], axis=1)

    # A single multi-output fit: each output column uses the same neighbours and
    # weights, so this matches fitting the R, G and B differences separately.
    predicted = regressor.fit(colors, deltas).predict(lut)
    predicted = np.asarray(predicted, dtype=np.float64).reshape(-1, 3)
    predicted[np.isnan(predicted)] = 0

    lut = lut + predicted
    return lut.clip(min=0)

https://github.com/ajcommercial/AI_color_grade_lut

apply 3dlut：这里利用三线性插值

def trilinear_forward(img, lut, lut_size=33):
    """
    Apply a 3D LUT to an image using trilinear interpolation.

    :param img: (..., 3) float numpy array (e.g. h * w * 3) with values in
        [0, 1]; values outside that range are clamped to the LUT domain
    :param lut: 3D LUT with lut_size**3 entries, shaped (lut_size**3, C) or
        (lut_size, lut_size, lut_size, C); entry (r, g, b) is at flat index
        ``r + g*lut_size + b*lut_size**2``
    :param lut_size: number of lattice points per axis, >= 2 (default: 33)
    :return: (..., C) array of interpolated LUT values
    :raises ValueError: if ``lut`` does not hold lut_size**3 entries
    """
    dim = lut_size
    lut = np.asarray(lut)
    lut = lut.reshape(-1, lut.shape[-1])
    if lut.shape[0] != dim ** 3:
        raise ValueError(
            f"lut has {lut.shape[0]} entries, expected lut_size**3 = {dim ** 3}")

    img = np.asarray(img)
    # Continuous lattice coordinates using the exact spacing 1 / (dim - 1).
    # Clamping to [0, 1] keeps every corner index inside the LUT.
    pos = np.clip(img[..., :3], 0.0, 1.0) * (dim - 1)
    # Lower-corner index, capped at dim - 2 so the "+1" corner always exists.
    # An input of exactly 1.0 lands in the last cell with fraction 1.0.
    base_idx = np.minimum(np.floor(pos), dim - 2).astype(np.int64)
    frac = pos - base_idx

    r_id, g_id, b_id = base_idx[..., 0], base_idx[..., 1], base_idx[..., 2]
    # Keep a trailing axis so the weights broadcast over the LUT channels.
    r_d, g_d, b_d = frac[..., 0:1], frac[..., 1:2], frac[..., 2:3]

    base = r_id + g_id * dim + b_id * dim * dim
    rgb = 0.0
    for dr, dg, db in np.ndindex(2, 2, 2):
        weight = ((r_d if dr else 1 - r_d)
                  * (g_d if dg else 1 - g_d)
                  * (b_d if db else 1 - b_d))
        rgb = rgb + weight * lut[base + dr + dg * dim + db * dim * dim]
    return rgb

也可以使用 pillow_lut Python 库中提供的更多方法进行处理。

标签:src,lut,dim,生成,...,np,apply,3dlut,size
来源: https://blog.csdn.net/tywwwww/article/details/122338432