initial commit, 4.5 stable
Some checks failed
🔗 GHA / 📊 Static checks (push) Has been cancelled
🔗 GHA / 🤖 Android (push) Has been cancelled
🔗 GHA / 🍏 iOS (push) Has been cancelled
🔗 GHA / 🐧 Linux (push) Has been cancelled
🔗 GHA / 🍎 macOS (push) Has been cancelled
🔗 GHA / 🏁 Windows (push) Has been cancelled
🔗 GHA / 🌐 Web (push) Has been cancelled
Some checks failed
🔗 GHA / 📊 Static checks (push) Has been cancelled
🔗 GHA / 🤖 Android (push) Has been cancelled
🔗 GHA / 🍏 iOS (push) Has been cancelled
🔗 GHA / 🐧 Linux (push) Has been cancelled
🔗 GHA / 🍎 macOS (push) Has been cancelled
🔗 GHA / 🏁 Windows (push) Has been cancelled
🔗 GHA / 🌐 Web (push) Has been cancelled
This commit is contained in:
856
thirdparty/icu4c/common/lstmbe.cpp
vendored
Normal file
856
thirdparty/icu4c/common/lstmbe.cpp
vendored
Normal file
@@ -0,0 +1,856 @@
|
||||
// © 2021 and later: Unicode, Inc. and others.
|
||||
// License & terms of use: http://www.unicode.org/copyright.html
|
||||
|
||||
#include <complex>
|
||||
#include <utility>
|
||||
|
||||
#include "unicode/utypes.h"
|
||||
|
||||
#if !UCONFIG_NO_BREAK_ITERATION
|
||||
|
||||
#include "brkeng.h"
|
||||
#include "charstr.h"
|
||||
#include "cmemory.h"
|
||||
#include "lstmbe.h"
|
||||
#include "putilimp.h"
|
||||
#include "uassert.h"
|
||||
#include "ubrkimpl.h"
|
||||
#include "uresimp.h"
|
||||
#include "uvectr32.h"
|
||||
#include "uvector.h"
|
||||
|
||||
#include "unicode/brkiter.h"
|
||||
#include "unicode/resbund.h"
|
||||
#include "unicode/ubrk.h"
|
||||
#include "unicode/uniset.h"
|
||||
#include "unicode/ustring.h"
|
||||
#include "unicode/utf.h"
|
||||
|
||||
U_NAMESPACE_BEGIN
|
||||
|
||||
// Uncomment the following #define to debug.
|
||||
// #define LSTM_DEBUG 1
|
||||
// #define LSTM_VECTORIZER_DEBUG 1
|
||||
|
||||
/**
|
||||
* Interface for reading 1D array.
|
||||
*/
|
||||
class ReadArray1D {
|
||||
public:
|
||||
virtual ~ReadArray1D();
|
||||
virtual int32_t d1() const = 0;
|
||||
virtual float get(int32_t i) const = 0;
|
||||
|
||||
#ifdef LSTM_DEBUG
|
||||
void print() const {
|
||||
printf("\n[");
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
printf("%0.8e ", get(i));
|
||||
if (i % 4 == 3) printf("\n");
|
||||
}
|
||||
printf("]\n");
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Out-of-line virtual destructor; the interface holds no state to release.
ReadArray1D::~ReadArray1D() = default;
|
||||
|
||||
/**
|
||||
* Interface for reading 2D array.
|
||||
*/
|
||||
class ReadArray2D {
|
||||
public:
|
||||
virtual ~ReadArray2D();
|
||||
virtual int32_t d1() const = 0;
|
||||
virtual int32_t d2() const = 0;
|
||||
virtual float get(int32_t i, int32_t j) const = 0;
|
||||
};
|
||||
|
||||
// Out-of-line virtual destructor; the interface holds no state to release.
ReadArray2D::~ReadArray2D() = default;
|
||||
|
||||
/**
|
||||
* A class to index a float array as a 1D Array without owning the pointer or
|
||||
* copy the data.
|
||||
*/
|
||||
class ConstArray1D : public ReadArray1D {
|
||||
public:
|
||||
ConstArray1D() : data_(nullptr), d1_(0) {}
|
||||
|
||||
ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
|
||||
|
||||
virtual ~ConstArray1D();
|
||||
|
||||
// Init the object, the object does not own the data nor copy.
|
||||
// It is designed to directly use data from memory mapped resources.
|
||||
void init(const int32_t* data, int32_t d1) {
|
||||
U_ASSERT(IEEE_754 == 1);
|
||||
data_ = reinterpret_cast<const float*>(data);
|
||||
d1_ = d1;
|
||||
}
|
||||
|
||||
// ReadArray1D methods.
|
||||
virtual int32_t d1() const override { return d1_; }
|
||||
virtual float get(int32_t i) const override {
|
||||
U_ASSERT(i < d1_);
|
||||
return data_[i];
|
||||
}
|
||||
|
||||
private:
|
||||
const float* data_;
|
||||
int32_t d1_;
|
||||
};
|
||||
|
||||
// Nothing to release: the view never owns the buffer it points at.
ConstArray1D::~ConstArray1D() = default;
|
||||
|
||||
/**
|
||||
* A class to index a float array as a 2D Array without owning the pointer or
|
||||
* copy the data.
|
||||
*/
|
||||
class ConstArray2D : public ReadArray2D {
|
||||
public:
|
||||
ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
|
||||
|
||||
ConstArray2D(const float* data, int32_t d1, int32_t d2)
|
||||
: data_(data), d1_(d1), d2_(d2) {}
|
||||
|
||||
virtual ~ConstArray2D();
|
||||
|
||||
// Init the object, the object does not own the data nor copy.
|
||||
// It is designed to directly use data from memory mapped resources.
|
||||
void init(const int32_t* data, int32_t d1, int32_t d2) {
|
||||
U_ASSERT(IEEE_754 == 1);
|
||||
data_ = reinterpret_cast<const float*>(data);
|
||||
d1_ = d1;
|
||||
d2_ = d2;
|
||||
}
|
||||
|
||||
// ReadArray2D methods.
|
||||
inline int32_t d1() const override { return d1_; }
|
||||
inline int32_t d2() const override { return d2_; }
|
||||
float get(int32_t i, int32_t j) const override {
|
||||
U_ASSERT(i < d1_);
|
||||
U_ASSERT(j < d2_);
|
||||
return data_[i * d2_ + j];
|
||||
}
|
||||
|
||||
// Expose the ith row as a ConstArray1D
|
||||
inline ConstArray1D row(int32_t i) const {
|
||||
U_ASSERT(i < d1_);
|
||||
return ConstArray1D(data_ + i * d2_, d2_);
|
||||
}
|
||||
|
||||
private:
|
||||
const float* data_;
|
||||
int32_t d1_;
|
||||
int32_t d2_;
|
||||
};
|
||||
|
||||
// Nothing to release: the view never owns the buffer it points at.
ConstArray2D::~ConstArray2D() = default;
|
||||
|
||||
/**
|
||||
* A class to allocate data as a writable 1D array.
|
||||
* This is the main class implement matrix operation.
|
||||
*/
|
||||
class Array1D : public ReadArray1D {
|
||||
public:
|
||||
Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
|
||||
Array1D(int32_t d1, UErrorCode &status)
|
||||
: memory_(uprv_malloc(d1 * sizeof(float))),
|
||||
data_(static_cast<float*>(memory_)), d1_(d1) {
|
||||
if (U_SUCCESS(status)) {
|
||||
if (memory_ == nullptr) {
|
||||
status = U_MEMORY_ALLOCATION_ERROR;
|
||||
return;
|
||||
}
|
||||
clear();
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~Array1D();
|
||||
|
||||
// A special constructor which does not own the memory but writeable
|
||||
// as a slice of an array.
|
||||
Array1D(float* data, int32_t d1)
|
||||
: memory_(nullptr), data_(data), d1_(d1) {}
|
||||
|
||||
// ReadArray1D methods.
|
||||
virtual int32_t d1() const override { return d1_; }
|
||||
virtual float get(int32_t i) const override {
|
||||
U_ASSERT(i < d1_);
|
||||
return data_[i];
|
||||
}
|
||||
|
||||
// Return the index which point to the max data in the array.
|
||||
inline int32_t maxIndex() const {
|
||||
int32_t index = 0;
|
||||
float max = data_[0];
|
||||
for (int32_t i = 1; i < d1_; i++) {
|
||||
if (data_[i] > max) {
|
||||
max = data_[i];
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
// Slice part of the array to a new one.
|
||||
inline Array1D slice(int32_t from, int32_t size) const {
|
||||
U_ASSERT(from >= 0);
|
||||
U_ASSERT(from < d1_);
|
||||
U_ASSERT(from + size <= d1_);
|
||||
return Array1D(data_ + from, size);
|
||||
}
|
||||
|
||||
// Add dot product of a 1D array and a 2D array into this one.
|
||||
inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
|
||||
U_ASSERT(a.d1() == b.d1());
|
||||
U_ASSERT(b.d2() == d1());
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
for (int32_t j = 0; j < a.d1(); j++) {
|
||||
data_[i] += a.get(j) * b.get(j, i);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Hadamard Product the values of another array of the same size into this one.
|
||||
inline Array1D& hadamardProduct(const ReadArray1D& a) {
|
||||
U_ASSERT(a.d1() == d1());
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
data_[i] *= a.get(i);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add the Hadamard Product of two arrays of the same size into this one.
|
||||
inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
|
||||
U_ASSERT(a.d1() == d1());
|
||||
U_ASSERT(b.d1() == d1());
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
data_[i] += a.get(i) * b.get(i);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add the values of another array of the same size into this one.
|
||||
inline Array1D& add(const ReadArray1D& a) {
|
||||
U_ASSERT(a.d1() == d1());
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
data_[i] += a.get(i);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Assign the values of another array of the same size into this one.
|
||||
inline Array1D& assign(const ReadArray1D& a) {
|
||||
U_ASSERT(a.d1() == d1());
|
||||
for (int32_t i = 0; i < d1(); i++) {
|
||||
data_[i] = a.get(i);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Apply tanh to all the elements in the array.
|
||||
inline Array1D& tanh() {
|
||||
return tanh(*this);
|
||||
}
|
||||
|
||||
// Apply tanh of a and store into this array.
|
||||
inline Array1D& tanh(const Array1D& a) {
|
||||
U_ASSERT(a.d1() == d1());
|
||||
for (int32_t i = 0; i < d1_; i++) {
|
||||
data_[i] = std::tanh(a.get(i));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Apply sigmoid to all the elements in the array.
|
||||
inline Array1D& sigmoid() {
|
||||
for (int32_t i = 0; i < d1_; i++) {
|
||||
data_[i] = 1.0f/(1.0f + expf(-data_[i]));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline Array1D& clear() {
|
||||
uprv_memset(data_, 0, d1_ * sizeof(float));
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
void* memory_;
|
||||
float* data_;
|
||||
int32_t d1_;
|
||||
};
|
||||
|
||||
// Hands memory_ back to the allocator. For slices and default-constructed
// arrays memory_ is nullptr, so only storage this object allocated is passed
// a real pointer.
Array1D::~Array1D() {
    uprv_free(memory_);
}
|
||||
|
||||
class Array2D : public ReadArray2D {
|
||||
public:
|
||||
Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
|
||||
Array2D(int32_t d1, int32_t d2, UErrorCode &status)
|
||||
: memory_(uprv_malloc(d1 * d2 * sizeof(float))),
|
||||
data_(static_cast<float*>(memory_)), d1_(d1), d2_(d2) {
|
||||
if (U_SUCCESS(status)) {
|
||||
if (memory_ == nullptr) {
|
||||
status = U_MEMORY_ALLOCATION_ERROR;
|
||||
return;
|
||||
}
|
||||
clear();
|
||||
}
|
||||
}
|
||||
virtual ~Array2D();
|
||||
|
||||
// ReadArray2D methods.
|
||||
virtual int32_t d1() const override { return d1_; }
|
||||
virtual int32_t d2() const override { return d2_; }
|
||||
virtual float get(int32_t i, int32_t j) const override {
|
||||
U_ASSERT(i < d1_);
|
||||
U_ASSERT(j < d2_);
|
||||
return data_[i * d2_ + j];
|
||||
}
|
||||
|
||||
inline Array1D row(int32_t i) const {
|
||||
U_ASSERT(i < d1_);
|
||||
return Array1D(data_ + i * d2_, d2_);
|
||||
}
|
||||
|
||||
inline Array2D& clear() {
|
||||
uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
void* memory_;
|
||||
float* data_;
|
||||
int32_t d1_;
|
||||
int32_t d2_;
|
||||
};
|
||||
|
||||
// Hands the matrix storage (memory_) back to the allocator.
Array2D::~Array2D() {
    uprv_free(memory_);
}
|
||||
|
||||
// Output classes of the model (the "BIES" labels). The numeric values are
// compared against the index of the largest output score, so their order
// must stay BEGIN, INSIDE, END, SINGLE.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
|
||||
|
||||
// How input text is split into units before the embedding lookup.
enum EmbeddingType {
    UNKNOWN = 0,           // model type string was not recognized
    CODE_POINTS = 1,       // one unit per code point
    GRAPHEME_CLUSTER = 2   // one unit per grapheme cluster
};
|
||||
|
||||
struct LSTMData : public UMemory {
|
||||
LSTMData(UResourceBundle* rb, UErrorCode &status);
|
||||
~LSTMData();
|
||||
UHashtable* fDict;
|
||||
EmbeddingType fType;
|
||||
const char16_t* fName;
|
||||
ConstArray2D fEmbedding;
|
||||
ConstArray2D fForwardW;
|
||||
ConstArray2D fForwardU;
|
||||
ConstArray1D fForwardB;
|
||||
ConstArray2D fBackwardW;
|
||||
ConstArray2D fBackwardU;
|
||||
ConstArray1D fBackwardB;
|
||||
ConstArray2D fOutputW;
|
||||
ConstArray1D fOutputB;
|
||||
|
||||
private:
|
||||
UResourceBundle* fBundle;
|
||||
};
|
||||
|
||||
// Loads an LSTM model from the resource bundle rb.
//
// Reads the "embeddings" and "hunits" sizes, the "type" and "model" strings,
// the "dict" string array (each string is mapped to its array index in
// fDict) and the "data" int vector, then points the matrix views at
// consecutive sections of that vector.
//
// rb is stored in fBundle on every path, including failures, so the
// destructor closes it.
//
// On failure status is set and the object is left partially initialized:
//  - U_UNSUPPORTED_ERROR if floats are not IEEE 754;
//  - U_INVALID_FORMAT_ERROR if the sizes are not positive or the data vector
//    is too short for the matrices they describe;
//  - whatever error the resource or hash table calls report.
LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
    : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
      fBundle(rb)
{
    if (U_FAILURE(status)) {
        return;
    }
    if (IEEE_754 != 1) {
        status = U_UNSUPPORTED_ERROR;
        return;
    }
    LocalUResourceBundlePointer embeddings_res(
        ures_getByKey(rb, "embeddings", nullptr, &status));
    int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
    LocalUResourceBundlePointer hunits_res(
        ures_getByKey(rb, "hunits", nullptr, &status));
    if (U_FAILURE(status)) return;
    int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
    const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
    if (U_FAILURE(status)) return;
    if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
        fType = CODE_POINTS;
    } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
        fType = GRAPHEME_CLUSTER;
    }
    fName = ures_getStringByKey(rb, "model", nullptr, &status);
    LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
    if (U_FAILURE(status)) return;
    int32_t data_len = 0;
    const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
    fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);

    StackUResourceBundle stackTempBundle;
    ResourceDataValue value;
    ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
    ResourceArray stringArray = value.getArray(status);
    int32_t num_index = stringArray.getSize();
    if (U_FAILURE(status)) { return; }

    // put dict into hash
    int32_t stringLength;
    for (int32_t idx = 0; idx < num_index; idx++) {
        stringArray.getValue(idx, value);
        const char16_t* str = value.getString(stringLength, status);
        uhash_putiAllowZero(fDict, const_cast<char16_t*>(str), idx, &status);
        if (U_FAILURE(status)) return;
#ifdef LSTM_VECTORIZER_DEBUG
        printf("Assign [");
        while (*str != 0x0000) {
            printf("U+%04x ", *str);
            str++;
        }
        printf("] map to %d\n", idx-1);
#endif
    }

    // Validate the declared sizes against the real length of the data vector
    // before any view is pointed into it. Without this, a truncated or
    // malformed resource makes the views reach past the end of the vector in
    // release builds, where U_ASSERT compiles to nothing.
    // The bounds below also keep the 64-bit products from overflowing.
    if (data == nullptr || num_index < 0 ||
            embedding_size <= 0 || hunits <= 0 ||
            embedding_size > data_len ||
            static_cast<int64_t>(hunits) * 4 > data_len) {
        status = U_INVALID_FORMAT_ERROR;
        return;
    }
    const int64_t e = embedding_size;
    const int64_t h = hunits;
    const int64_t required =
        (static_cast<int64_t>(num_index) + 1) * e +   // embedding
        2 * (e * 4 * h + h * 4 * h + 4 * h) +         // forward + backward W, U, b
        2 * h * 4 +                                   // output W
        4;                                            // output b
    if (required > data_len) {
        status = U_INVALID_FORMAT_ERROR;
        return;
    }
    // Well-formed data is expected to match exactly.
    U_ASSERT(data_len == required);

    // Every section fits in data_len, so the int32_t sizes cannot overflow.
    int32_t mat1_size = (num_index + 1) * embedding_size;
    int32_t mat2_size = embedding_size * 4 * hunits;
    int32_t mat3_size = hunits * 4 * hunits;
    int32_t mat4_size = 4 * hunits;
    int32_t mat5_size = mat2_size;
    int32_t mat6_size = mat3_size;
    int32_t mat7_size = mat4_size;
    int32_t mat8_size = 2 * hunits * 4;

    // Point each view at its section; the sections are stored back to back.
    fEmbedding.init(data, (num_index + 1), embedding_size);
    data += mat1_size;
    fForwardW.init(data, embedding_size, 4 * hunits);
    data += mat2_size;
    fForwardU.init(data, hunits, 4 * hunits);
    data += mat3_size;
    fForwardB.init(data, 4 * hunits);
    data += mat4_size;
    fBackwardW.init(data, embedding_size, 4 * hunits);
    data += mat5_size;
    fBackwardU.init(data, hunits, 4 * hunits);
    data += mat6_size;
    fBackwardB.init(data, 4 * hunits);
    data += mat7_size;
    fOutputW.init(data, 2 * hunits, 4);
    data += mat8_size;
    fOutputB.init(data, 4);
}
|
||||
|
||||
// Closes the dictionary hash table and then the resource bundle that was
// passed to the constructor.
LSTMData::~LSTMData()
{
    uhash_close(fDict);
    ures_close(fBundle);
}
|
||||
|
||||
class Vectorizer : public UMemory {
|
||||
public:
|
||||
Vectorizer(UHashtable* dict) : fDict(dict) {}
|
||||
virtual ~Vectorizer();
|
||||
virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
|
||||
UVector32 &offsets, UVector32 &indices,
|
||||
UErrorCode &status) const = 0;
|
||||
protected:
|
||||
int32_t stringToIndex(const char16_t* str) const {
|
||||
UBool found = false;
|
||||
int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
|
||||
if (!found) {
|
||||
ret = fDict->count;
|
||||
}
|
||||
#ifdef LSTM_VECTORIZER_DEBUG
|
||||
printf("[");
|
||||
while (*str != 0x0000) {
|
||||
printf("U+%04x ", *str);
|
||||
str++;
|
||||
}
|
||||
printf("] map to %d\n", ret);
|
||||
#endif
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
UHashtable* fDict;
|
||||
};
|
||||
|
||||
// The dictionary is borrowed, so there is nothing to release.
Vectorizer::~Vectorizer() = default;
|
||||
|
||||
class CodePointsVectorizer : public Vectorizer {
|
||||
public:
|
||||
CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
|
||||
virtual ~CodePointsVectorizer();
|
||||
virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
|
||||
UVector32 &offsets, UVector32 &indices,
|
||||
UErrorCode &status) const override;
|
||||
};
|
||||
|
||||
// No state of its own to release.
CodePointsVectorizer::~CodePointsVectorizer() = default;
|
||||
|
||||
// Walks text[startPos, endPos) one code point at a time. For every code
// point, its native offset is appended to `offsets` and its dictionary index
// to `indices`. Does nothing if the vectors cannot be grown or status is
// already a failure after growing them.
void CodePointsVectorizer::vectorize(
    UText *text, int32_t startPos, int32_t endPos,
    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
{
    if (!offsets.ensureCapacity(endPos - startPos, status) ||
            !indices.ensureCapacity(endPos - startPos, status)) {
        return;
    }
    if (U_FAILURE(status)) {
        return;
    }
    utext_setNativeIndex(text, startPos);

    // One code unit followed by the terminating NUL expected by the lookup.
    char16_t unit[2] = {0, 0};
    while (U_SUCCESS(status)) {
        const int32_t position = static_cast<int32_t>(utext_getNativeIndex(text));
        if (position >= endPos) {
            break;
        }
        // The LSTMBreakEngine currently only accepts characters in the BMP,
        // so a supplementary code point is not expected here and the value
        // is narrowed to a single code unit.
        unit[0] = static_cast<char16_t>(utext_next32(text));
        U_ASSERT(!U_IS_SURROGATE(unit[0]));
        offsets.addElement(position, status);
        indices.addElement(stringToIndex(unit), status);
    }
}
|
||||
|
||||
class GraphemeClusterVectorizer : public Vectorizer {
|
||||
public:
|
||||
GraphemeClusterVectorizer(UHashtable* dict)
|
||||
: Vectorizer(dict)
|
||||
{
|
||||
}
|
||||
virtual ~GraphemeClusterVectorizer();
|
||||
virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
|
||||
UVector32 &offsets, UVector32 &indices,
|
||||
UErrorCode &status) const override;
|
||||
};
|
||||
|
||||
// No state of its own to release.
GraphemeClusterVectorizer::~GraphemeClusterVectorizer() = default;
|
||||
|
||||
constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
|
||||
|
||||
void GraphemeClusterVectorizer::vectorize(
|
||||
UText *text, int32_t startPos, int32_t endPos,
|
||||
UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
|
||||
{
|
||||
if (U_FAILURE(status)) return;
|
||||
if (!offsets.ensureCapacity(endPos - startPos, status) ||
|
||||
!indices.ensureCapacity(endPos - startPos, status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
|
||||
if (U_FAILURE(status)) return;
|
||||
graphemeIter->setText(text, status);
|
||||
if (U_FAILURE(status)) return;
|
||||
|
||||
if (startPos != 0) {
|
||||
graphemeIter->preceding(startPos);
|
||||
}
|
||||
int32_t last = startPos;
|
||||
int32_t current = startPos;
|
||||
char16_t str[MAX_GRAPHEME_CLSTER_LENGTH];
|
||||
while ((current = graphemeIter->next()) != BreakIterator::DONE) {
|
||||
if (current >= endPos) {
|
||||
break;
|
||||
}
|
||||
if (current > startPos) {
|
||||
utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
|
||||
if (U_FAILURE(status)) return;
|
||||
offsets.addElement(last, status);
|
||||
indices.addElement(stringToIndex(str), status);
|
||||
if (U_FAILURE(status)) return;
|
||||
}
|
||||
last = current;
|
||||
}
|
||||
if (U_FAILURE(status) || last >= endPos) {
|
||||
return;
|
||||
}
|
||||
utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
|
||||
if (U_SUCCESS(status)) {
|
||||
offsets.addElement(last, status);
|
||||
indices.addElement(stringToIndex(str), status);
|
||||
}
|
||||
}
|
||||
|
||||
// Computing LSTM as stated in
|
||||
// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
|
||||
// ifco is temp array allocate outside which does not need to be
|
||||
// input/output value but could avoid unnecessary memory alloc/free if passing
|
||||
// in.
|
||||
void compute(
|
||||
int32_t hunits,
|
||||
const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
|
||||
const ReadArray1D& x, Array1D& h, Array1D& c,
|
||||
Array1D& ifco)
|
||||
{
|
||||
// ifco = x * W + h * U + b
|
||||
ifco.assign(b)
|
||||
.addDotProduct(x, W)
|
||||
.addDotProduct(h, U);
|
||||
|
||||
ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
|
||||
ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
|
||||
ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
|
||||
ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
|
||||
|
||||
c.hadamardProduct(ifco.slice(hunits, hunits))
|
||||
.addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
|
||||
|
||||
h.tanh(c)
|
||||
.hadamardProduct(ifco.slice(3*hunits, hunits));
|
||||
}
|
||||
|
||||
// Smallest number of characters a word may have.
static constexpr int32_t MIN_WORD = 2;

// Smallest number of characters that can hold two words.
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
|
||||
|
||||
// Finds word boundaries in text[startPos, endPos) with the bidirectional
// LSTM and appends them to foundBreaks.
//
// Returns the number of boundaries appended, or 0 if the range is too short
// to hold two words or if an error occurred (status reports the error).
int32_t
LSTMBreakEngine::divideUpDictionaryRange( UText *text,
                                                int32_t startPos,
                                                int32_t endPos,
                                                UVector32 &foundBreaks,
                                                UBool /* isPhraseBreaking */,
                                                UErrorCode& status) const {
    if (U_FAILURE(status)) {
        return 0;
    }
    const int32_t breaksBefore = foundBreaks.size();

    // Probe whether the range can hold two words at all.
    utext_setNativeIndex(text, startPos);
    utext_moveIndex32(text, MIN_WORD_SPAN);
    if (utext_getNativeIndex(text) >= endPos) {
        return 0;       // Not enough characters for two words
    }
    utext_setNativeIndex(text, startPos);

    // Turn the text into embedding indices plus the offset of each unit.
    UVector32 offsets(status);
    UVector32 indices(status);
    if (U_FAILURE(status)) {
        return 0;
    }
    fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
    if (U_FAILURE(status)) {
        return 0;
    }
    int32_t* offsetsBuf = offsets.getBuffer();
    int32_t* indicesBuf = indices.getBuffer();

    const int32_t input_seq_len = indices.size();
    const int32_t hunits = fData->fForwardU.d1();

    // ----- Begin of all the Array memory allocation needed for this function
    // Scratch array used inside compute().
    Array1D ifco(4 * hunits, status);

    Array1D c(hunits, status);
    Array1D logp(4, status);

    // TODO: limit size of hBackward. If input_seq_len is too big, we could
    // run out of memory.
    // Hidden states of the backward LSTM, one row per input unit.
    Array2D hBackward(input_seq_len, hunits, status);

    // Concatenation of the forward and backward hidden state for one step.
    Array1D fbRow(2 * hunits, status);

    // ----- End of all the Array memory allocation needed for this function
    if (U_FAILURE(status)) {
        return 0;
    }

    // To reduce memory usage this differs from the Python and ICU4X
    // implementations: the backward LSTM runs first and stores all of its
    // hidden states, then the forward LSTM is interleaved with the output
    // layer, because the forward pass only needs to remember h[t-1].
    const int32_t lastStep = input_seq_len - 1;
    for (int32_t step = lastStep; step >= 0; --step) {
        Array1D hRow = hBackward.row(step);
        if (step < lastStep) {
            // Start from the hidden state of the following unit.
            hRow.assign(hBackward.row(step + 1));
        }
#ifdef LSTM_DEBUG
        printf("hRow %d\n", step);
        hRow.print();
        printf("indicesBuf[%d] = %d\n", step, indicesBuf[step]);
        printf("fData->fEmbedding.row(indicesBuf[%d]):\n", step);
        fData->fEmbedding.row(indicesBuf[step]).print();
#endif  // LSTM_DEBUG
        compute(hunits,
                fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
                fData->fEmbedding.row(indicesBuf[step]),
                hRow, c, ifco);
    }

    Array1D forwardRow = fbRow.slice(0, hunits);  // first half of fbRow
    Array1D backwardRow = fbRow.slice(hunits, hunits);  // second half of fbRow

    // Forward LSTM merged with the output layer.
    c.clear();  // reuse c since it is the same size.
    for (int32_t step = 0; step < input_seq_len; ++step) {
#ifdef LSTM_DEBUG
        printf("forwardRow %d\n", step);
        forwardRow.print();
#endif  // LSTM_DEBUG
        // The forward hidden state is updated in place in the first half of
        // fbRow.
        compute(hunits,
                fData->fForwardW, fData->fForwardU, fData->fForwardB,
                fData->fEmbedding.row(indicesBuf[step]),
                forwardRow, c, ifco);

        // The matching backward hidden state goes into the second half.
        backwardRow.assign(hBackward.row(step));

        logp.assign(fData->fOutputB);
        logp.addDotProduct(fbRow, fData->fOutputW);
#ifdef LSTM_DEBUG
        printf("backwardRow %d\n", step);
        backwardRow.print();
        printf("logp %d\n", step);
        logp.print();
#endif  // LSTM_DEBUG

        // label = argmax(logp)
        const LSTMClass label = static_cast<LSTMClass>(logp.maxIndex());
        // BIES logic: a unit that begins a word, or is a word by itself,
        // has a boundary in front of it. The very first unit is skipped.
        const bool startsWord = (label == BEGIN || label == SINGLE);
        if (startsWord && step != 0) {
            foundBreaks.addElement(offsetsBuf[step], status);
            if (U_FAILURE(status)) {
                return 0;
            }
        }
    }
    return foundBreaks.size() - breaksBefore;
}
|
||||
|
||||
// Creates the Vectorizer matching the embedding type of `data`.
//
// Returns nullptr if status is already a failure on entry, or if the
// allocation fails, in which case status is set to
// U_MEMORY_ALLOCATION_ERROR. Reporting the allocation failure keeps a caller
// that only looks at status from going on to use a null vectorizer.
// An embedding type other than CODE_POINTS or GRAPHEME_CLUSTER is treated as
// unreachable.
// The caller owns the returned object.
Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
    if (U_FAILURE(status)) {
        return nullptr;
    }
    Vectorizer* vectorizer = nullptr;
    switch (data->fType) {
        case CODE_POINTS:
            vectorizer = new CodePointsVectorizer(data->fDict);
            break;
        case GRAPHEME_CLUSTER:
            vectorizer = new GraphemeClusterVectorizer(data->fDict);
            break;
        default:
            UPRV_UNREACHABLE_EXIT;
    }
    if (vectorizer == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
    }
    return vectorizer;
}
|
||||
|
||||
// Builds an engine over `data` that handles the characters in `set`.
// On success the engine keeps `data` and deletes it in its destructor.
// On failure the engine drops its pointer to `data`, so the destructor does
// not delete it; the caller is expected to do so.
LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
    : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
{
    if (U_SUCCESS(status)) {
        setCharacters(set);
        return;
    }
    // If failure, we should not delete fData in destructor because the
    // caller will do so.
    fData = nullptr;
}
|
||||
|
||||
// Deletes the model data (nullptr if construction failed) and the vectorizer.
LSTMBreakEngine::~LSTMBreakEngine()
{
    delete fData;
    delete fVectorizer;
}
|
||||
|
||||
const char16_t* LSTMBreakEngine::name() const {
|
||||
return fData->fName;
|
||||
}
|
||||
|
||||
// Looks up the default LSTM model for `script`: the string stored under the
// script's short name in the "lstm" table of the root break-iterator bundle.
// Errors are reported through status.
UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
    // open root from brkitr tree.
    UResourceBundle *bundle = ures_open(U_ICUDATA_BRKITR, "", &status);
    // The same object is reused as the fill-in for the "lstm" table.
    bundle = ures_getByKeyWithFallback(bundle, "lstm", bundle, &status);
    UnicodeString modelName =
        ures_getUnicodeStringByKey(bundle, uscript_getShortName(script), &status);
    ures_close(bundle);
    return modelName;
}
|
||||
|
||||
// Loads the default LSTM model for `script`.
//
// Returns nullptr without touching status for scripts other than Khmer, Lao,
// Myanmar and Thai, and nullptr with status set if the model cannot be found
// or opened. Otherwise returns the object created by CreateLSTMData(), which
// takes over the opened bundle.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
{
    if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
        return nullptr;
    }
    UnicodeString name = defaultLSTM(script, status);
    if (U_FAILURE(status)) return nullptr;

    // Convert the name and strip its extension in separate statements, so
    // that the '.' is searched for only after the name has been appended,
    // regardless of how the compiler orders a chained call's arguments.
    CharString namebuf;
    namebuf.appendInvariantChars(name, status);
    if (U_FAILURE(status)) return nullptr;
    const int32_t dot = namebuf.lastIndexOf('.');
    if (dot >= 0) {
        // Drop the extension; a name without one is used as is.
        namebuf.truncate(dot);
    }

    LocalUResourceBundlePointer rb(
        ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
    if (U_FAILURE(status)) return nullptr;

    return CreateLSTMData(rb.orphan(), status);
}
|
||||
|
||||
// Creates an LSTMData from the bundle rb. Ownership of rb passes to this
// function: the LSTMData closes it in its destructor.
//
// If the object cannot be allocated, rb is closed here so it does not leak,
// status is set to U_MEMORY_ALLOCATION_ERROR (unless it already holds a
// failure) and nullptr is returned.
//
// If loading fails, the partially initialized object is still returned with
// status set, as before.
// NOTE(review): callers presumably release it with DeleteLSTMData() in that
// case - confirm against the call sites.
U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
{
    const LSTMData* data = new LSTMData(rb, status);
    if (data == nullptr) {
        ures_close(rb);
        if (U_SUCCESS(status)) {
            status = U_MEMORY_ALLOCATION_ERROR;
        }
    }
    return data;
}
|
||||
|
||||
// Creates an LSTM break engine for `script` using the model in `data`.
//
// Only Thai and Myanmar are supported. For any other script `data` is
// deleted and nullptr is returned with status untouched.
// If the engine cannot be allocated, status is set to
// U_MEMORY_ALLOCATION_ERROR; if its construction fails, the engine is
// deleted. Both cases return nullptr without deleting `data`.
// NOTE(review): `data` is deleted here for unsupported scripts but not on
// the other failure paths - confirm that callers handle both.
U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
{
    // Pattern for the characters the engine handles.
    const char16_t* pattern = nullptr;
    if (script == USCRIPT_THAI) {
        pattern = u"[[:Thai:]&[:LineBreak=SA:]]";
    } else if (script == USCRIPT_MYANMAR) {
        pattern = u"[[:Mymr:]&[:LineBreak=SA:]]";
    } else {
        delete data;
        return nullptr;
    }
    UnicodeString unicodeSetString = UnicodeString(pattern);

    UnicodeSet unicodeSet;
    unicodeSet.applyPattern(unicodeSetString, status);

    const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
    if (engine == nullptr) {
        status = U_MEMORY_ALLOCATION_ERROR;
        return nullptr;
    }
    if (U_FAILURE(status)) {
        delete engine;
        return nullptr;
    }
    return engine;
}
|
||||
|
||||
// Destroys an LSTMData object. Its destructor closes the dictionary and the
// resource bundle it holds.
U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data) {
    delete data;
}
|
||||
|
||||
// Returns the model name stored in `data`. The string is not copied.
U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data) {
    const char16_t* modelName = data->fName;
    return modelName;
}
|
||||
|
||||
U_NAMESPACE_END
|
||||
|
||||
#endif /* #if !UCONFIG_NO_BREAK_ITERATION */
|
Reference in New Issue
Block a user