kendryte-standalone-sdk/lib/nncase/v1/include/nncase/runtime/bitio.h

168 lines
4.2 KiB
C++

/* Copyright 2020 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "datatypes.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <span>
namespace nncase::runtime
{
class bitreader
{
public:
bitreader(std::span<const uint8_t> data)
: data_(data), buffer_(0), avail_(0) { }
void read(uint8_t *dest, size_t bits)
{
while (bits)
{
auto to_read = std::min(bits, size_t(8));
*dest++ = read_bits_le8(to_read);
bits -= to_read;
}
}
template <class T, size_t Bits>
T read()
{
T ret {};
read(reinterpret_cast<uint8_t *>(&ret), Bits);
return ret;
}
private:
uint8_t read_bits_le8(size_t bits)
{
assert(bits <= 8);
fill_buffer_le8(bits);
uint8_t ret = buffer_ & ((size_t(1) << bits) - 1);
buffer_ >>= bits;
avail_ -= bits;
return ret;
}
void fill_buffer_le8(size_t bits)
{
if (avail_ < bits)
{
auto max_read_bytes = std::min(data_.size() * 8, sizeof(buffer_) * 8 - avail_) / 8;
assert(max_read_bytes != 0);
uint64_t tmp = 0;
std::memcpy(&tmp, data_.data(), max_read_bytes);
data_ = data_.subspan(max_read_bytes);
buffer_ = buffer_ | (tmp << avail_);
avail_ += max_read_bytes * 8;
}
}
private:
std::span<const uint8_t> data_;
uint64_t buffer_;
size_t avail_;
};
class bitwriter
{
public:
bitwriter(std::span<uint8_t> data, size_t bitoffset = 0)
: data_(data), buffer_(0), avail_(sizeof(buffer_) * 8)
{
if (bitoffset)
{
data_ = data_.subspan(bitoffset / 8);
bitoffset %= 8;
buffer_ = data_.front() & ((size_t(1) << bitoffset) - 1);
avail_ -= bitoffset;
}
}
~bitwriter() { flush(); }
void write(const uint8_t *src, size_t bits)
{
while (bits)
{
auto to_write = std::min(bits, size_t(8));
write_bits_le8(*src++, to_write);
bits -= to_write;
}
}
template <size_t Bits, class T>
void write(T value)
{
write(reinterpret_cast<const uint8_t *>(&value), Bits);
}
void flush()
{
auto write_bytes = (buffer_written_bits() + 7) / 8;
if (write_bytes)
{
assert(data_.size() >= write_bytes);
std::memcpy(data_.data(), &buffer_, write_bytes);
data_ = data_.subspan(write_bytes);
buffer_ = 0;
avail_ = sizeof(buffer_) * 8;
}
}
private:
void write_bits_le8(uint8_t value, size_t bits)
{
assert(bits <= 8);
reserve_buffer_8();
size_t new_value = value & ((size_t(1) << bits) - 1);
buffer_ = buffer_ | (new_value << buffer_written_bits());
avail_ -= bits;
}
void reserve_buffer_8()
{
if (avail_ < 8)
{
auto write_bytes = buffer_written_bits() / 8;
assert(data_.size() >= write_bytes);
std::memcpy(data_.data(), &buffer_, write_bytes);
data_ = data_.subspan(write_bytes);
if (write_bytes == sizeof(buffer_))
buffer_ = 0;
else
buffer_ >>= write_bytes * 8;
avail_ += write_bytes * 8;
}
}
size_t buffer_written_bits() const noexcept
{
return sizeof(buffer_) * 8 - avail_;
}
private:
std::span<uint8_t> data_;
uint64_t buffer_;
size_t avail_;
};
}