kendryte-standalone-sdk/lib/nncase/v1/include/nncase/runtime/bfloat16.h

354 lines
9.6 KiB
C++

/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cmath>
#include <cstdint>
#include <cstring>
#include <float.h>
#include <functional>
#include <limits>
#include <type_traits>
#include <nncase/runtime/compiler_defs.h>
namespace nncase
{
// Plain holder for a raw 16-bit pattern. No conversions or arithmetic are
// defined for it in this header.
// NOTE(review): presumably the bit pattern of a half-precision float -- confirm
// against the code that consumes it.
struct half
{
    uint16_t value; // raw 16-bit payload
};
// Empty tag type used for overload selection: passing an instance picks the
// constructor that takes an already-encoded bit pattern.
struct from_raw_t
{
    // Explicit so that a bare `{}` cannot silently convert into the tag.
    explicit from_raw_t() = default;
};
// Tag instance handed to bfloat16's (from_raw_t, uint16_t) constructor to
// build a value directly from its 16-bit encoding.
NNCASE_INLINE_VAR constexpr from_raw_t from_raw {};
// 16-bit floating point value stored as the upper 16 bits of a 32-bit float
// (1 sign bit, 8 exponent bits, 7 explicit fraction bits).
//
// Construction from float/arithmetic types truncates the low 16 bits;
// round_to_bfloat16() is the round-to-nearest-even alternative used by the
// arithmetic operators. Conversion back to float is exact.
struct bfloat16
{
private:
    static_assert(sizeof(float) == sizeof(uint32_t), "bfloat16 requires a 32-bit float");

    // A value that represents "zero".
    static constexpr uint16_t ZERO_VALUE = 0;
    // A value that represents "not a number": all exponent bits set plus the
    // most significant fraction bit (quiet NaN).
    static constexpr uint16_t NAN_VALUE = 0x7FC0;

    // Returns the 32-bit encoding of `v`. memcpy is used instead of a union or
    // pointer cast so the access is well-defined and independent of byte order.
    static uint32_t float_to_bits(float v) noexcept
    {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits));
        return bits;
    }

    // Inverse of float_to_bits().
    static float bits_to_float(uint32_t bits) noexcept
    {
        float v;
        std::memcpy(&v, &bits, sizeof(v));
        return v;
    }

public:
    // Leaves the value uninitialized (keeps the type trivially constructible).
    bfloat16() noexcept = default;

    // Converts by truncation (drops the low 16 bits of the float encoding);
    // NaN inputs become the canonical quiet NaN.
    explicit bfloat16(float v) noexcept
        : value_(truncate_to_bfloat16(v).value_) { }

    // Converts any integral or floating point value by first casting it to
    // float and then truncating.
    template <class T, class = std::enable_if_t<std::is_integral<T>::value || std::is_floating_point<T>::value>>
    explicit bfloat16(const T &val) noexcept
        : bfloat16(static_cast<float>(val)) { }

    // Builds a value directly from its 16-bit encoding.
    constexpr bfloat16(from_raw_t, uint16_t value) noexcept
        : value_(value) { }

    // Widens to float: the stored bits become the upper half of the float
    // encoding and the lower half is zero.
    operator float() const noexcept
    {
        return bits_to_float(static_cast<uint32_t>(value_) << 16);
    }

    // Access to the underlying 16-bit encoding.
    const uint16_t &raw() const noexcept { return value_; }
    uint16_t &raw() noexcept { return value_; }

    // Builds a value directly from its 16-bit encoding.
    static constexpr bfloat16 from_raw(uint16_t v) noexcept
    {
        return bfloat16(nncase::from_raw, v);
    }

    // Converts a float to bfloat16 by keeping only the upper 16 bits of its
    // encoding. NaN maps to the canonical quiet NaN so that dropping fraction
    // bits can never turn a NaN into an infinity.
    static bfloat16 truncate_to_bfloat16(const float v) noexcept
    {
        if (std::isnan(v))
            return from_raw(NAN_VALUE);
        return from_raw(static_cast<uint16_t>(float_to_bits(v) >> 16));
    }

    // Converts a float point to bfloat16, with round-nearest-to-even as rounding
    // method.
    static bfloat16 round_to_bfloat16(float v) noexcept
    {
        // If the value is a NaN, squash it to a qNaN with msb of fraction set,
        // this makes sure after truncation we don't end up with an inf.
        if (std::isnan(v))
            return from_raw(NAN_VALUE);

        uint32_t bits = float_to_bits(v);
        // Least significant bit of resulting bfloat.
        const uint32_t lsb = (bits >> 16) & 1;
        // Adding 0x7fff rounds to nearest; the extra `lsb` breaks exact ties
        // towards the even result. A carry out of the fraction bumps the
        // exponent, so values just above the largest finite one become inf.
        const uint32_t rounding_bias = 0x7fff + lsb;
        bits += rounding_bias;
        return from_raw(static_cast<uint16_t>(bits >> 16));
    }

    static constexpr bfloat16 epsilon() noexcept
    {
        // 0x1.0p-7
        return from_raw(0x3c00);
    }

    static constexpr bfloat16 highest() noexcept
    {
        // 0x1.FEp127
        return from_raw(0x7F7F);
    }

    static constexpr bfloat16 min() noexcept
    {
        // 0x1p-126
        return from_raw(0x0080);
    }

    static constexpr bfloat16 lowest() noexcept
    {
        // -0x1.FEp127
        return from_raw(0xFF7F);
    }

    static constexpr bfloat16 nan() noexcept
    {
        return from_raw(NAN_VALUE);
    }

    static constexpr bfloat16 infinity() noexcept
    {
        return from_raw(0x7f80);
    }

    // True for both +0 and -0 (the sign bit is masked off before comparing).
    constexpr bool zero() const noexcept { return (value_ & 0x7FFF) == ZERO_VALUE; }

private:
    uint16_t value_;
};
// Generates `bfloat16 operator OP(bfloat16, bfloat16)`: both operands are
// widened to float, combined there, and the float result is converted back
// with bfloat16::round_to_bfloat16.
#define DEFINE_BF16_BINARY_BF16RET(op) \
    inline bfloat16 operator op(bfloat16 lhs, bfloat16 rhs) noexcept \
    { \
        const float wide = static_cast<float>(lhs) op static_cast<float>(rhs); \
        return bfloat16::round_to_bfloat16(wide); \
    }

// Generates `bool operator OP(bfloat16, bfloat16)`: the comparison is carried
// out on the operands' float values.
#define DEFINE_BF16_BINARY_BOOLRET(op) \
    inline bool operator op(bfloat16 lhs, bfloat16 rhs) noexcept \
    { \
        return static_cast<float>(lhs) op static_cast<float>(rhs); \
    }

// Arithmetic operators.
DEFINE_BF16_BINARY_BF16RET(+)
DEFINE_BF16_BINARY_BF16RET(-)
DEFINE_BF16_BINARY_BF16RET(*)
DEFINE_BF16_BINARY_BF16RET(/)

// Ordering operators.
DEFINE_BF16_BINARY_BOOLRET(<)
DEFINE_BF16_BINARY_BOOLRET(<=)
DEFINE_BF16_BINARY_BOOLRET(>=)
DEFINE_BF16_BINARY_BOOLRET(>)
// Generates a compound-assignment operator (`assign_op`, e.g. +=) in terms of
// the matching binary operator (`binary_op`, e.g. +). The left operand is
// updated in place and returned by reference.
#define DEFINE_BF16_BINARY_SELF_MOD(assign_op, binary_op) \
    inline bfloat16 &operator assign_op(bfloat16 &target, bfloat16 operand) noexcept \
    { \
        return target = target binary_op operand; \
    }

DEFINE_BF16_BINARY_SELF_MOD(+=, +)
DEFINE_BF16_BINARY_SELF_MOD(-=, -)
DEFINE_BF16_BINARY_SELF_MOD(*=, *)
DEFINE_BF16_BINARY_SELF_MOD(/=, /)
// Unary minus: negates the float value of `a` and converts the result back
// with bfloat16::round_to_bfloat16.
inline bfloat16 operator-(bfloat16 a) noexcept
{
    const float negated = -static_cast<float>(a);
    return bfloat16::round_to_bfloat16(negated);
}
// Equality uses floating point value semantics, matching the ordering
// operators (<, <=, >=, >) which also compare as float:
//   * +0 and -0 compare equal although their encodings differ;
//   * NaN never compares equal to anything, including itself.
// Use raw() when a bit-for-bit comparison of the encodings is wanted.
inline bool operator==(const bfloat16 &lhs, const bfloat16 &rhs) noexcept
{
    return static_cast<float>(lhs) == static_cast<float>(rhs);
}

// Exact negation of operator==, so NaN != NaN is true.
inline bool operator!=(const bfloat16 &lhs, const bfloat16 &rhs) noexcept
{
    return !(lhs == rhs);
}
}
namespace std
{
// Hash support for nncase::bfloat16: the value is widened to float and hashed
// with the float hasher.
template <>
struct hash<nncase::bfloat16>
{
    size_t operator()(const nncase::bfloat16 &v) const
    {
        const float widened = static_cast<float>(v);
        return hash<float>()(widened);
    }
};
// std::numeric_limits for nncase::bfloat16.
//
// bfloat16 shares float's exponent range (hence the FLT_* exponent limits) but
// has only 8 significand bits (7 stored + 1 implicit).
template <>
struct numeric_limits<nncase::bfloat16>
{
    static constexpr bool is_specialized = true;
    static constexpr bool is_signed = true;
    static constexpr bool is_integer = false;
    static constexpr bool is_exact = false;
    static constexpr bool is_modulo = false;
    static constexpr bool is_bounded = true;
    static constexpr bool is_iec559 = true;
    static constexpr bool has_infinity = true;
    static constexpr bool has_quiet_NaN = true;
    static constexpr bool has_signaling_NaN = true;
    static constexpr float_denorm_style has_denorm = denorm_present;
    static constexpr bool has_denorm_loss = false;
    // Arithmetic is carried out in float, so these follow float's behaviour.
    static constexpr bool traps = std::numeric_limits<float>::traps;
    static constexpr bool tinyness_before = std::numeric_limits<float>::tinyness_before;
    static constexpr float_round_style round_style = round_to_nearest;
    static constexpr int radix = FLT_RADIX;
    // Significand bits, including the implicit leading one.
    static constexpr int digits = 8;
    // floor((digits - 1) * log10(2))
    static constexpr int digits10 = 2;
    // ceil(digits * log10(2) + 1)
    static constexpr int max_digits10 = 4;
    static constexpr int min_exponent = FLT_MIN_EXP;
    static constexpr int min_exponent10 = FLT_MIN_10_EXP;
    static constexpr int max_exponent = FLT_MAX_EXP;
    static constexpr int max_exponent10 = FLT_MAX_10_EXP;

    // Smallest positive normal value (0x1p-126).
    NNCASE_UNUSED static constexpr nncase::bfloat16(min)() noexcept
    {
        return nncase::bfloat16::min();
    }

    // Largest finite value.
    NNCASE_UNUSED static constexpr nncase::bfloat16(max)() noexcept
    {
        return nncase::bfloat16::highest();
    }

    // Most negative finite value.
    NNCASE_UNUSED static constexpr nncase::bfloat16 lowest() noexcept
    {
        return nncase::bfloat16::lowest();
    }

    NNCASE_UNUSED static constexpr nncase::bfloat16 epsilon() noexcept
    {
        return nncase::bfloat16::epsilon();
    }

    NNCASE_UNUSED static constexpr nncase::bfloat16 round_error() noexcept
    {
        // 0.5
        return nncase::bfloat16::from_raw(0x3f00);
    }

    // Smallest positive subnormal value: only the lowest fraction bit set
    // (0x1p-133). This is distinct from min(), the smallest *normal* value.
    NNCASE_UNUSED static constexpr nncase::bfloat16 denorm_min() noexcept
    {
        return nncase::bfloat16::from_raw(0x0001);
    }

    NNCASE_UNUSED static constexpr nncase::bfloat16 infinity() noexcept
    {
        return nncase::bfloat16::infinity();
    }

    NNCASE_UNUSED static constexpr nncase::bfloat16 quiet_NaN() noexcept
    {
        return nncase::bfloat16::nan();
    }

    // Exponent all ones, quiet bit (fraction msb) clear, non-zero fraction.
    // Built from raw bits because the float conversions canonicalize every
    // NaN to the quiet pattern.
    NNCASE_UNUSED static constexpr nncase::bfloat16 signaling_NaN() noexcept
    {
        return nncase::bfloat16::from_raw(0x7FA0);
    }
};
using nncase::bfloat16;

// Classification helpers: each one widens the argument to float and defers to
// the float overload of the same name.
inline bool isinf(const bfloat16 &a)
{
    return std::isinf(static_cast<float>(a));
}

inline bool isnan(const bfloat16 &a)
{
    return std::isnan(static_cast<float>(a));
}

inline bool isfinite(const bfloat16 &a)
{
    return std::isfinite(static_cast<float>(a));
}
// Elementary math overloads for bfloat16. Every function follows the same
// recipe: widen the argument(s) to float, call the single-precision C math
// routine, and convert the result back with bfloat16::round_to_bfloat16.

inline bfloat16 abs(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(fabsf(wide));
}

inline bfloat16 exp(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(expf(wide));
}

inline bfloat16 log(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(logf(wide));
}

inline bfloat16 log10(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(log10f(wide));
}

inline bfloat16 sqrt(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(sqrtf(wide));
}

// Raises `a` to the power `b`.
inline bfloat16 pow(const bfloat16 &a, const bfloat16 &b)
{
    const float base = static_cast<float>(a);
    const float exponent = static_cast<float>(b);
    return bfloat16::round_to_bfloat16(powf(base, exponent));
}

inline bfloat16 sin(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(sinf(wide));
}

inline bfloat16 cos(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(cosf(wide));
}

inline bfloat16 tan(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(tanf(wide));
}

inline bfloat16 tanh(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(tanhf(wide));
}
// Integer-rounding overloads for bfloat16. The argument is widened to float and
// passed to the single-precision C routine; results that stay floating point
// are converted back with bfloat16::round_to_bfloat16.

inline bfloat16 floor(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(floorf(wide));
}

inline bfloat16 ceil(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(ceilf(wide));
}

inline bfloat16 round(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(roundf(wide));
}

inline bfloat16 nearbyint(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return bfloat16::round_to_bfloat16(nearbyintf(wide));
}

// Returns the result of lrintf on the float value of `a`, as a long.
inline long lrint(const bfloat16 &a)
{
    const float wide = static_cast<float>(a);
    return lrintf(wide);
}
} // namespace std