其他分享
首页 > 其他分享 > c – 浮点乘法:AVX 反而比 SSE 慢?

c – 浮点乘法:AVX 反而比 SSE 慢?

作者:互联网

我有两份功能相同的代码,但 AVX 版本比 SSE 版本慢得多。有人可以解释一下原因吗?

我已经尝试用 VerySleepy 对代码做性能分析,但没有得到任何有用的结果,它只是证实了 AVX 版本确实更慢……

我查阅了 SSE / AVX 指令指南中这些指令在我的 CPU(Haswell)上的数据,它们的延迟/吞吐量相同,只是 AVX 的水平相加需要一条额外的指令……

**延迟和吞吐量**

_mm_mul_ps            -> L 5, T 0.5
_mm256_mul_ps         -> L 5, T 0.5
_mm_hadd_ps           -> L 5, T 2
_mm256_hadd_ps        -> L 5, T ?
_mm256_extractf128_ps -> L 1, T 1

代码的作用摘要:
Final1 = SUM(m_Array1 * m_Array1 * m_Array3 * m_Array3)

Final2 = SUM(m_Array2 * m_Array2 * m_Array3 * m_Array3)

Final3 = SUM(m_Array1 * m_Array2 * m_Array3 * m_Array3)

初始化:

// Scalar running totals for the three sums; the SSE and AVX loops below add
// one partial sum per iteration into each of them.
float Final1 = 0.0f;
float Final2 = 0.0f;
float Final3 = 0.0f;

// Three input buffers of 32 floats each, allocated with 32-byte alignment
// (second argument of _mm_malloc). The loops below read them with the aligned
// loads _mm_load_ps and _mm256_load_ps.
// NOTE(review): this snippet neither fills the buffers nor releases them with
// _mm_free — presumably elided from the question; confirm.
float* m_Array1 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );
float* m_Array2 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );
float* m_Array3 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );

SSE:

// SSE version: walks the 32-element arrays four floats at a time
// (8 iterations) and reduces every 128-bit product vector to a scalar
// inside the loop.
for ( int k = 0; k < 32; k += 4 )
{

    // Aligned 128-bit loads of four consecutive floats from each array.
    __m128 g1 = _mm_load_ps( m_Array1 + k );
    __m128 g2 = _mm_load_ps( m_Array2 + k );
    __m128 g3 = _mm_load_ps( m_Array3 + k );

    // Element-wise products that are shared by the three results.
    __m128 g1g3 = _mm_mul_ps( g1, g3 );
    __m128 g2g3 = _mm_mul_ps( g2, g3 );

    // Per lane: a1 = (g1*g3)^2, a2 = (g2*g3)^2, a3 = (g1*g3)*(g2*g3).
    __m128 a1 = _mm_mul_ps( g1g3, g1g3 );
    __m128 a2 = _mm_mul_ps( g2g3, g2g3 );
    __m128 a3 = _mm_mul_ps( g1g3, g2g3 );

    // horizontal add
    // Two _mm_hadd_ps passes leave the sum of all four lanes in lane 0;
    // _mm_cvtss_f32 extracts that lane so it can be added to the scalar total.
    {
        a1 = _mm_hadd_ps( a1, a1 );
        a1 = _mm_hadd_ps( a1, a1 );
        Final1 += _mm_cvtss_f32( a1 );

        a2 = _mm_hadd_ps( a2, a2 );
        a2 = _mm_hadd_ps( a2, a2 );
        Final2 += _mm_cvtss_f32( a2 );

        a3 = _mm_hadd_ps( a3, a3 );
        a3 = _mm_hadd_ps( a3, a3 );
        Final3 += _mm_cvtss_f32( a3 );

    }

}

AVX:

// AVX version: walks the 32-element arrays eight floats at a time
// (4 iterations) and reduces every 256-bit product vector to a scalar
// inside the loop.
for ( int k = 0; k < 32; k += 8 )
{
    // Aligned 256-bit loads of eight consecutive floats from each array.
    __m256 g1 = _mm256_load_ps( m_Array1 + k );
    __m256 g2 = _mm256_load_ps( m_Array2 + k );
    __m256 g3 = _mm256_load_ps( m_Array3 + k );

    // Element-wise products that are shared by the three results.
    __m256 g1g3 = _mm256_mul_ps( g1, g3 );
    __m256 g2g3 = _mm256_mul_ps( g2, g3 );

    // Per lane: a1 = (g1*g3)^2, a2 = (g2*g3)^2, a3 = (g1*g3)*(g2*g3).
    __m256 a1 = _mm256_mul_ps( g1g3, g1g3 );
    __m256 a2 = _mm256_mul_ps( g2g3, g2g3 );
    __m256 a3 = _mm256_mul_ps( g1g3, g2g3 );

    // horizontal add1
    // _mm256_hadd_ps adds pairs within each 128-bit half separately, so two
    // passes give one partial sum per half. The upper half is then extracted
    // and added to the lower half with _mm_add_ss to obtain the 8-lane sum.
    {
        __m256 t1 = _mm256_hadd_ps( a1, a1 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final1 += _mm_cvtss_f32( t4 );
    }
    // horizontal add2
    // Same reduction as above, applied to a2.
    {
        __m256 t1 = _mm256_hadd_ps( a2, a2 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final2 += _mm_cvtss_f32( t4 );
    }
    // horizontal add3
    // Same reduction as above, applied to a3.
    {
        __m256 t1 = _mm256_hadd_ps( a3, a3 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final3 += _mm_cvtss_f32( t4 );
    }

}

解决方法:

我把你的代码放进了一个测试程序,用 clang -O3 编译并计时。我还为这两个例程各实现了一个更快的版本,把水平相加移到了循环之外:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>   // gettimeofday
#include <immintrin.h>

/*
 * Baseline SSE kernel: the horizontal add is performed inside the loop.
 *
 * Computes, over i in [0, n):
 *   *Final1 = sum( (m_Array1[i] * m_Array3[i])^2 )
 *   *Final2 = sum( (m_Array2[i] * m_Array3[i])^2 )
 *   *Final3 = sum( (m_Array1[i] * m_Array3[i]) * (m_Array2[i] * m_Array3[i]) )
 *
 * The input arrays are read with _mm_load_ps and must be 16-byte aligned.
 * n may be any value: whole groups of four floats take the vector path and
 * the remaining n % 4 elements are handled by a scalar loop, so nothing is
 * read beyond index n - 1.
 */
static void sse(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    const size_t vec_end = n - (n % 4);
    float sum1 = 0.0f;
    float sum2 = 0.0f;
    float sum3 = 0.0f;
    size_t k = 0;   // size_t so the index matches the type (and range) of n

    for (; k < vec_end; k += 4)
    {
        __m128 g1 = _mm_load_ps( m_Array1 + k );
        __m128 g2 = _mm_load_ps( m_Array2 + k );
        __m128 g3 = _mm_load_ps( m_Array3 + k );

        __m128 g1g3 = _mm_mul_ps( g1, g3 );
        __m128 g2g3 = _mm_mul_ps( g2, g3 );

        __m128 a1 = _mm_mul_ps( g1g3, g1g3 );
        __m128 a2 = _mm_mul_ps( g2g3, g2g3 );
        __m128 a3 = _mm_mul_ps( g1g3, g2g3 );

        // horizontal add: two hadd passes leave the 4-lane sum in lane 0
        a1 = _mm_hadd_ps( a1, a1 );
        a1 = _mm_hadd_ps( a1, a1 );
        sum1 += _mm_cvtss_f32( a1 );

        a2 = _mm_hadd_ps( a2, a2 );
        a2 = _mm_hadd_ps( a2, a2 );
        sum2 += _mm_cvtss_f32( a2 );

        a3 = _mm_hadd_ps( a3, a3 );
        a3 = _mm_hadd_ps( a3, a3 );
        sum3 += _mm_cvtss_f32( a3 );
    }

    // scalar tail for the last n % 4 elements
    for (; k < n; ++k)
    {
        const float g1g3 = m_Array1[k] * m_Array3[k];
        const float g2g3 = m_Array2[k] * m_Array3[k];

        sum1 += g1g3 * g1g3;
        sum2 += g2g3 * g2g3;
        sum3 += g1g3 * g2g3;
    }

    // Results are written once, after all reads of the inputs.
    *Final1 = sum1;
    *Final2 = sum2;
    *Final3 = sum3;
}

/*
 * Optimised SSE kernel: products are accumulated in vector registers and the
 * horizontal add is performed only once, after the loop.
 *
 * Computes, over i in [0, n):
 *   *Final1 = sum( (m_Array1[i] * m_Array3[i])^2 )
 *   *Final2 = sum( (m_Array2[i] * m_Array3[i])^2 )
 *   *Final3 = sum( (m_Array1[i] * m_Array3[i]) * (m_Array2[i] * m_Array3[i]) )
 *
 * The input arrays are read with _mm_load_ps and must be 16-byte aligned.
 * n may be any value: whole groups of four floats take the vector path and
 * the remaining n % 4 elements are handled by a scalar loop, so nothing is
 * read beyond index n - 1.
 */
static void sse_fast(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    const size_t vec_end = n - (n % 4);
    float sum1 = 0.0f;
    float sum2 = 0.0f;
    float sum3 = 0.0f;
    size_t k = 0;   // size_t so the index matches the type (and range) of n

    // per-lane running sums
    __m128 a1 = _mm_setzero_ps();
    __m128 a2 = _mm_setzero_ps();
    __m128 a3 = _mm_setzero_ps();

    for (; k < vec_end; k += 4)
    {
        __m128 g1 = _mm_load_ps( m_Array1 + k );
        __m128 g2 = _mm_load_ps( m_Array2 + k );
        __m128 g3 = _mm_load_ps( m_Array3 + k );

        __m128 g1g3 = _mm_mul_ps( g1, g3 );
        __m128 g2g3 = _mm_mul_ps( g2, g3 );

        a1 = _mm_add_ps(a1, _mm_mul_ps( g1g3, g1g3 ));
        a2 = _mm_add_ps(a2, _mm_mul_ps( g2g3, g2g3 ));
        a3 = _mm_add_ps(a3, _mm_mul_ps( g1g3, g2g3 ));
    }

    // horizontal add: two hadd passes leave the 4-lane sum in lane 0
    a1 = _mm_hadd_ps( a1, a1 );
    a1 = _mm_hadd_ps( a1, a1 );
    sum1 += _mm_cvtss_f32( a1 );

    a2 = _mm_hadd_ps( a2, a2 );
    a2 = _mm_hadd_ps( a2, a2 );
    sum2 += _mm_cvtss_f32( a2 );

    a3 = _mm_hadd_ps( a3, a3 );
    a3 = _mm_hadd_ps( a3, a3 );
    sum3 += _mm_cvtss_f32( a3 );

    // scalar tail for the last n % 4 elements
    for (; k < n; ++k)
    {
        const float g1g3 = m_Array1[k] * m_Array3[k];
        const float g2g3 = m_Array2[k] * m_Array3[k];

        sum1 += g1g3 * g1g3;
        sum2 += g2g3 * g2g3;
        sum3 += g1g3 * g2g3;
    }

    *Final1 = sum1;
    *Final2 = sum2;
    *Final3 = sum3;
}

/*
 * Baseline AVX kernel: the horizontal add is performed inside the loop.
 *
 * Computes, over i in [0, n):
 *   *Final1 = sum( (m_Array1[i] * m_Array3[i])^2 )
 *   *Final2 = sum( (m_Array2[i] * m_Array3[i])^2 )
 *   *Final3 = sum( (m_Array1[i] * m_Array3[i]) * (m_Array2[i] * m_Array3[i]) )
 *
 * The input arrays are read with _mm256_load_ps and must be 32-byte aligned.
 * n may be any value: whole groups of eight floats take the vector path and
 * the remaining n % 8 elements are handled by a scalar loop, so nothing is
 * read beyond index n - 1.
 */
static void avx(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    const size_t vec_end = n - (n % 8);
    float sum1 = 0.0f;
    float sum2 = 0.0f;
    float sum3 = 0.0f;
    size_t k = 0;   // size_t so the index matches the type (and range) of n

    for (; k < vec_end; k += 8)
    {
        __m256 g1 = _mm256_load_ps( m_Array1 + k );
        __m256 g2 = _mm256_load_ps( m_Array2 + k );
        __m256 g3 = _mm256_load_ps( m_Array3 + k );

        __m256 g1g3 = _mm256_mul_ps( g1, g3 );
        __m256 g2g3 = _mm256_mul_ps( g2, g3 );

        __m256 a1 = _mm256_mul_ps( g1g3, g1g3 );
        __m256 a2 = _mm256_mul_ps( g2g3, g2g3 );
        __m256 a3 = _mm256_mul_ps( g1g3, g2g3 );

        // horizontal add1: _mm256_hadd_ps works on each 128-bit half
        // separately, so the two half sums are combined with _mm_add_ss.
        {
            __m256 t1 = _mm256_hadd_ps( a1, a1 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            sum1 += _mm_cvtss_f32( t4 );
        }
        // horizontal add2
        {
            __m256 t1 = _mm256_hadd_ps( a2, a2 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            sum2 += _mm_cvtss_f32( t4 );
        }
        // horizontal add3
        {
            __m256 t1 = _mm256_hadd_ps( a3, a3 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            sum3 += _mm_cvtss_f32( t4 );
        }
    }

    // scalar tail for the last n % 8 elements
    for (; k < n; ++k)
    {
        const float g1g3 = m_Array1[k] * m_Array3[k];
        const float g2g3 = m_Array2[k] * m_Array3[k];

        sum1 += g1g3 * g1g3;
        sum2 += g2g3 * g2g3;
        sum3 += g1g3 * g2g3;
    }

    // Results are written once, after all reads of the inputs.
    *Final1 = sum1;
    *Final2 = sum2;
    *Final3 = sum3;
}

/*
 * Optimised AVX kernel: products are accumulated in vector registers and the
 * horizontal add is performed only once, after the loop.
 *
 * Computes, over i in [0, n):
 *   *Final1 = sum( (m_Array1[i] * m_Array3[i])^2 )
 *   *Final2 = sum( (m_Array2[i] * m_Array3[i])^2 )
 *   *Final3 = sum( (m_Array1[i] * m_Array3[i]) * (m_Array2[i] * m_Array3[i]) )
 *
 * The input arrays are read with _mm256_load_ps and must be 32-byte aligned.
 * n may be any value: whole groups of eight floats take the vector path and
 * the remaining n % 8 elements are handled by a scalar loop, so nothing is
 * read beyond index n - 1.
 */
static void avx_fast(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    const size_t vec_end = n - (n % 8);
    float sum1 = 0.0f;
    float sum2 = 0.0f;
    float sum3 = 0.0f;
    size_t k = 0;   // size_t so the index matches the type (and range) of n

    // per-lane running sums
    __m256 a1 = _mm256_setzero_ps();
    __m256 a2 = _mm256_setzero_ps();
    __m256 a3 = _mm256_setzero_ps();

    for (; k < vec_end; k += 8)
    {
        __m256 g1 = _mm256_load_ps( m_Array1 + k );
        __m256 g2 = _mm256_load_ps( m_Array2 + k );
        __m256 g3 = _mm256_load_ps( m_Array3 + k );

        __m256 g1g3 = _mm256_mul_ps( g1, g3 );
        __m256 g2g3 = _mm256_mul_ps( g2, g3 );

        a1 = _mm256_add_ps(a1, _mm256_mul_ps( g1g3, g1g3 ));
        a2 = _mm256_add_ps(a2, _mm256_mul_ps( g2g3, g2g3 ));
        a3 = _mm256_add_ps(a3, _mm256_mul_ps( g1g3, g2g3 ));
    }

    // horizontal add1: _mm256_hadd_ps works on each 128-bit half separately,
    // so the two half sums are combined with _mm_add_ss.
    {
        __m256 t1 = _mm256_hadd_ps( a1, a1 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        sum1 += _mm_cvtss_f32( t4 );
    }

    // horizontal add2
    {
        __m256 t1 = _mm256_hadd_ps( a2, a2 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        sum2 += _mm_cvtss_f32( t4 );
    }

    // horizontal add3
    {
        __m256 t1 = _mm256_hadd_ps( a3, a3 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        sum3 += _mm_cvtss_f32( t4 );
    }

    // scalar tail for the last n % 8 elements
    for (; k < n; ++k)
    {
        const float g1g3 = m_Array1[k] * m_Array3[k];
        const float g2g3 = m_Array2[k] * m_Array3[k];

        sum1 += g1g3 * g1g3;
        sum2 += g2g3 * g2g3;
        sum3 += g1g3 * g2g3;
    }

    *Final1 = sum1;
    *Final2 = sum2;
    *Final3 = sum3;
}

/* Milliseconds elapsed between two gettimeofday() samples. */
static double elapsed_ms(const struct timeval *start, const struct timeval *end)
{
    return ((double)(end->tv_sec - start->tv_sec) + (double)(end->tv_usec - start->tv_usec) * 1.0e-6) * 1.0e3;
}

/*
 * Benchmark driver.
 *
 * Usage: prog [n]
 *   n  number of floats per input array (default 4096, must be positive).
 *
 * Fills three arrays with pseudo-random values in [0, 2], runs each of the
 * four kernels once and prints its results, then times 100 calls of each
 * kernel with gettimeofday() and prints the results together with the
 * elapsed time in milliseconds.
 *
 * Returns 0 on success, EXIT_FAILURE on a bad argument or allocation failure.
 */
int main(int argc, char *argv[])
{
    enum { VEC_FLOATS = 8, VEC_ALIGN = 32, REPEATS = 100 };

    size_t n = 4096;

    if (argc > 1)
    {
        // strtol instead of atoi so that garbage and non-positive values are
        // rejected rather than silently turned into 0 or a huge size_t.
        char *end = NULL;
        long requested = strtol(argv[1], &end, 10);

        if (end == argv[1] || *end != '\0' || requested <= 0)
        {
            fprintf(stderr, "usage: %s [n]  (n must be a positive integer)\n", argv[0]);
            return EXIT_FAILURE;
        }
        n = (size_t)requested;
    }

    // Guard the size computation below against overflow.
    if (n > (size_t)-1 / sizeof(float) - VEC_FLOATS)
    {
        fprintf(stderr, "n is too large: %s\n", argv[1]);
        return EXIT_FAILURE;
    }

    // Pad the allocation to a whole number of 8-float vectors and zero the
    // padding, so the vector kernels never touch uninitialized or
    // out-of-bounds memory when n is not a multiple of the vector width
    // (zero padding contributes nothing to any of the sums).
    const size_t padded = (n + VEC_FLOATS - 1) / VEC_FLOATS * VEC_FLOATS;

    // _mm_malloc gives the 32-byte alignment the aligned AVX loads require.
    float *in_1 = _mm_malloc(padded * sizeof *in_1, VEC_ALIGN);
    float *in_2 = _mm_malloc(padded * sizeof *in_2, VEC_ALIGN);
    float *in_3 = _mm_malloc(padded * sizeof *in_3, VEC_ALIGN);
    float out_1, out_2, out_3;

    if (in_1 == NULL || in_2 == NULL || in_3 == NULL)
    {
        fprintf(stderr, "out of memory\n");
        _mm_free(in_1);
        _mm_free(in_2);
        _mm_free(in_3);
        return EXIT_FAILURE;
    }

    struct timeval t0, t1;
    double t_ms;

    for (size_t i = 0; i < n; ++i)
    {
        in_1[i] = (float)rand() / (float)(RAND_MAX / 2);
        in_2[i] = (float)rand() / (float)(RAND_MAX / 2);
        in_3[i] = (float)rand() / (float)(RAND_MAX / 2);
    }
    for (size_t i = n; i < padded; ++i)
    {
        in_1[i] = 0.0f;
        in_2[i] = 0.0f;
        in_3[i] = 0.0f;
    }

    // One untimed call per kernel so the results can be compared by eye.
    sse(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("sse     : %g, %g, %g\n", out_1, out_2, out_3);
    sse_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("sse_fast: %g, %g, %g\n", out_1, out_2, out_3);
    avx(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("avx     : %g, %g, %g\n", out_1, out_2, out_3);
    avx_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("avx_fast: %g, %g, %g\n", out_1, out_2, out_3);

    // Timed runs: REPEATS back-to-back calls of each kernel.
    gettimeofday(&t0, NULL);
    for (int rep = 0; rep < REPEATS; ++rep)
    {
        sse(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    }
    gettimeofday(&t1, NULL);
    t_ms = elapsed_ms(&t0, &t1);
    printf("sse     : %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int rep = 0; rep < REPEATS; ++rep)
    {
        sse_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    }
    gettimeofday(&t1, NULL);
    t_ms = elapsed_ms(&t0, &t1);
    printf("sse_fast: %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int rep = 0; rep < REPEATS; ++rep)
    {
        avx(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    }
    gettimeofday(&t1, NULL);
    t_ms = elapsed_ms(&t0, &t1);
    printf("avx     : %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int rep = 0; rep < REPEATS; ++rep)
    {
        avx_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    }
    gettimeofday(&t1, NULL);
    t_ms = elapsed_ms(&t0, &t1);
    printf("avx_fast: %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    _mm_free(in_1);
    _mm_free(in_2);
    _mm_free(in_3);

    return 0;
}

我的2.6 GHz Haswell(MacBook Pro)的结果如下:

sse     : 0.439 ms
sse_fast: 0.153 ms
avx     : 0.309 ms
avx_fast: 0.085 ms

因此,无论是原始实现还是优化后的实现,AVX 版本看起来确实都比 SSE 版本快。不过,优化后的实现比原始版本快得多,提升幅度甚至更大。

我只能猜测:要么你的编译器没有为 AVX 生成很好的代码(或者你忘了启用编译器优化?),要么你的基准测试方法存在问题。

标签:c,performance,sse,avx
来源: https://codeday.me/bug/20190830/1765187.html