Skip to content

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705877615
  • Loading branch information
ericsalo authored and copybara-github committed Feb 20, 2025
1 parent 871b5ba commit 06bd53c
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 110 deletions.
16 changes: 8 additions & 8 deletions mozolm/models/ngram_char_fst_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@
#include "absl/memory/memory.h"
#include "nisaba/port/utf8_util.h"

using fst::MATCH_INPUT;
using fst::Matcher;
using fst::StdArc;
using fst::StdVectorFst;
using fst::Times;
using nlp_fst::MATCH_INPUT;
using nlp_fst::Matcher;
using nlp_fst::StdArc;
using nlp_fst::StdVectorFst;
using nlp_fst::Times;

namespace mozolm {
namespace models {

fst::StdArc::Label NGramCharFstModel::SymLabel(int utf8_sym) const {
nlp_fst::StdArc::Label NGramCharFstModel::SymLabel(int utf8_sym) const {
if (utf8_sym == 0) return utf8_sym;
const std::string u_char = nisaba::utf8::EncodeUnicodeChar(utf8_sym);
fst::StdArc::Label label = fst_->InputSymbols()->Find(u_char);
if (label == fst::kNoSymbol) {
nlp_fst::StdArc::Label label = fst_->InputSymbols()->Find(u_char);
if (label == nlp_fst::kNoSymbol) {
label = oov_label_;
}
return label;
Expand Down
10 changes: 5 additions & 5 deletions mozolm/models/ngram_char_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ class NGramCharFstModel : public NGramFstModel {
protected:
// Computes negative log probability for observing the supplied label in a
// given state.
fst::StdArc::Weight LabelCostInState(fst::StdArc::StateId state,
fst::StdArc::Label label) const;
nlp_fst::StdArc::Weight LabelCostInState(nlp_fst::StdArc::StateId state,
nlp_fst::StdArc::Label label) const;

private:
fst::StdArc::Label SymLabel(int utf8_sym) const;
nlp_fst::StdArc::Label SymLabel(int utf8_sym) const;

// Returns negative log probability of the end-of-string at the given state.
fst::StdArc::Weight FinalCostInState(
fst::StdArc::StateId state) const;
nlp_fst::StdArc::Weight FinalCostInState(
nlp_fst::StdArc::StateId state) const;
};

} // namespace models
Expand Down
2 changes: 1 addition & 1 deletion mozolm/models/ngram_char_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include "nisaba/port/test_utils.h"

using nisaba::testing::TestFilePath;
using fst::StdArc;
using nlp_fst::StdArc;

namespace mozolm {
namespace models {
Expand Down
12 changes: 6 additions & 6 deletions mozolm/models/ngram_fst_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"

using fst::MATCH_INPUT;
using fst::Matcher;
using fst::StdArc;
using fst::StdVectorFst;
using fst::SymbolTable;
using nlp_fst::MATCH_INPUT;
using nlp_fst::Matcher;
using nlp_fst::StdArc;
using nlp_fst::StdVectorFst;
using nlp_fst::SymbolTable;

namespace mozolm {
namespace models {
Expand All @@ -43,7 +43,7 @@ absl::Status NGramFstModel::Read(const ModelStorage &storage) {
return absl::InvalidArgumentError("Model file not specified");
}
GOOGLE_LOG(INFO) << "Initializing from " << storage.model_file() << " ...";
std::unique_ptr<fst::StdVectorFst> fst;
std::unique_ptr<nlp_fst::StdVectorFst> fst;
fst.reset(StdVectorFst::Read(storage.model_file()));
if (!fst) {
return absl::NotFoundError(absl::StrCat("Failed to read FST from ",
Expand Down
20 changes: 10 additions & 10 deletions mozolm/models/ngram_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,32 @@ class NGramFstModel : public LanguageModel {
int64_t count) override;

// Returns underlying FST, which must be initialized.
const fst::StdVectorFst &fst() const { return *fst_; }
const nlp_fst::StdVectorFst &fst() const { return *fst_; }

fst::StdArc::Label oov_label() const { return oov_label_; }
nlp_fst::StdArc::Label oov_label() const { return oov_label_; }

protected:
NGramFstModel() = default;

// Returns the next state reached by arc labeled with label from state s.
// If the label is out-of-vocabulary, it will return the unigram state.
fst::StdArc::StateId NextModelState(
fst::StdArc::StateId current_state,
fst::StdArc::Label label) const;
nlp_fst::StdArc::StateId NextModelState(
nlp_fst::StdArc::StateId current_state,
nlp_fst::StdArc::Label label) const;

// Language model represented by vector FST.
std::unique_ptr<const fst::StdVectorFst> fst_;
std::unique_ptr<const nlp_fst::StdVectorFst> fst_;

// N-Gram model helper wrapping the FST above.
std::unique_ptr<const ngram::NGramModel<fst::StdArc>> model_;
std::unique_ptr<const ngram::NGramModel<nlp_fst::StdArc>> model_;

// Label for the unknown symbol, if any.
fst::StdArc::Label oov_label_ = fst::kNoSymbol;
nlp_fst::StdArc::Label oov_label_ = nlp_fst::kNoSymbol;

// Checks the current state and sets it to the unigram state if less than
// zero.
fst::StdArc::StateId CheckCurrentState(
fst::StdArc::StateId state) const;
nlp_fst::StdArc::StateId CheckCurrentState(
nlp_fst::StdArc::StateId state) const;

private:
// Performs model sanity check.
Expand Down
14 changes: 7 additions & 7 deletions mozolm/models/ngram_word_fst_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ using nisaba::utf8::EncodeUnicodeChar;
using nisaba::utf8::StrSplitByChar;

using absl::StatusOr;
using fst::ArcIterator;
using fst::MATCH_INPUT;
using fst::Matcher;
using fst::StdArc;
using fst::StdVectorFst;
using fst::SymbolTable;
using fst::Times;
using nlp_fst::ArcIterator;
using nlp_fst::MATCH_INPUT;
using nlp_fst::Matcher;
using nlp_fst::StdArc;
using nlp_fst::StdVectorFst;
using nlp_fst::SymbolTable;
using nlp_fst::Times;

namespace mozolm {
namespace models {
Expand Down
4 changes: 2 additions & 2 deletions mozolm/models/ngram_word_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class NGramImplicitStates {
public:
NGramImplicitStates() = default;

NGramImplicitStates(const fst::StdVectorFst& fst,
NGramImplicitStates(const nlp_fst::StdVectorFst& fst,
int first_char_begin_index, int first_char_end_index);

// Returns the state if already exists, creates it otherwise.
Expand Down Expand Up @@ -168,7 +168,7 @@ class NGramWordFstModel : public NGramFstModel {
int FindOldestLastAccessedCache() const;

// Returns new cache index for given state.
absl::Status GetNewCacheIndex(fst::StdArc::StateId s,
absl::Status GetNewCacheIndex(nlp_fst::StdArc::StateId s,
const std::vector<double>& weights);

// Returns cache index if it exists, creates new cache entry otherwise.
Expand Down
10 changes: 5 additions & 5 deletions mozolm/models/ngram_word_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
#include "nisaba/port/utf8_util.h"

using ::nisaba::testing::TestFilePath;
using ::fst::ArcSort;
using ::fst::ILabelCompare;
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::SymbolTable;
using ::nlp_fst::ArcSort;
using ::nlp_fst::ILabelCompare;
using ::nlp_fst::StdArc;
using ::nlp_fst::StdVectorFst;
using ::nlp_fst::SymbolTable;

namespace mozolm {
namespace models {
Expand Down
16 changes: 8 additions & 8 deletions mozolm/models/ppm_as_fst_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ using nisaba::file::ReadLines;
using nisaba::utf8::EncodeUnicodeChar;
using nisaba::utf8::StrSplitByChar;

using fst::ArcIterator;
using fst::ILabelCompare;
using fst::Log64Weight;
using fst::MutableArcIterator;
using fst::StdArc;
using fst::StdVectorFst;
using fst::SymbolTable;
using fst::SymbolTableIterator;
using nlp_fst::ArcIterator;
using nlp_fst::ILabelCompare;
using nlp_fst::Log64Weight;
using nlp_fst::MutableArcIterator;
using nlp_fst::StdArc;
using nlp_fst::StdVectorFst;
using nlp_fst::SymbolTable;
using nlp_fst::SymbolTableIterator;

namespace impl {
namespace {
Expand Down
60 changes: 30 additions & 30 deletions mozolm/models/ppm_as_fst_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class PpmStateCache {
}

// Fills in LMScores proto from values in cached state.
bool FillLMScores(const fst::SymbolTable& syms, LMScores* response) const;
bool FillLMScores(const nlp_fst::SymbolTable& syms, LMScores* response) const;

private:
int state_; // Index of state being cached.
Expand All @@ -171,7 +171,7 @@ class PpmAsFstModel : public LanguageModel {
absl::Status WriteFst(const std::string& ofile) const override;

// Returns fst_.
const fst::StdVectorFst GetFst() const { return *fst_; }
const nlp_fst::StdVectorFst GetFst() const { return *fst_; }

// Provides the state reached from state following utf8_sym.
int NextState(int state, int utf8_sym) override;
Expand Down Expand Up @@ -223,8 +223,8 @@ class PpmAsFstModel : public LanguageModel {
absl::Status CalculateStateOrders(bool save_state_orders);

// Determines whether new state needs to be created for arc.
absl::StatusOr<bool> NeedsNewState(fst::StdArc::StateId curr_state,
fst::StdArc::StateId next_state) const;
absl::StatusOr<bool> NeedsNewState(nlp_fst::StdArc::StateId curr_state,
nlp_fst::StdArc::StateId next_state) const;

// Adds extra characters to unigram of model if provided.
absl::Status AddExtraCharacters(const std::string& input_string);
Expand All @@ -233,23 +233,23 @@ class PpmAsFstModel : public LanguageModel {
// for each item in the vocabulary, matching indices with the symbol table. By
// convention, index 0 is for final cost. Checks for empty states and ensures
// backoff states are cached.
absl::Status UpdateCacheAtState(fst::StdArc::StateId s);
absl::Status UpdateCacheAtState(nlp_fst::StdArc::StateId s);

// Initializes negative log probabilities for cache based on backoff.
std::vector<double> InitCacheProbs(fst::StdArc::StateId s,
fst::StdArc::StateId backoff_state,
std::vector<double> InitCacheProbs(nlp_fst::StdArc::StateId s,
nlp_fst::StdArc::StateId backoff_state,
const PpmStateCache& backoff_cache,
double denominator) const;

// Initializes the origin and destination states for cache based on backoff.
std::vector<int> InitCacheStates(fst::StdArc::StateId s,
fst::StdArc::StateId backoff_state,
std::vector<int> InitCacheStates(nlp_fst::StdArc::StateId s,
nlp_fst::StdArc::StateId backoff_state,
const PpmStateCache& backoff_cache,
bool arc_origin) const;

// Fills in values for states and probs vectors from state for cache.
absl::Status UpdateCacheStatesAndProbs(
fst::StdArc::StateId s, fst::StdArc::StateId backoff_state,
nlp_fst::StdArc::StateId s, nlp_fst::StdArc::StateId backoff_state,
double denominator, std::vector<int>* arc_origin_states,
std::vector<int>* destination_states,
std::vector<double>* neg_log_probabilities);
Expand All @@ -258,57 +258,57 @@ class PpmAsFstModel : public LanguageModel {
// for each item in the vocabulary, matching indices with the symbol table. By
// convention, index 0 is for final cost.
absl::Status UpdateCacheAtNonEmptyState(
fst::StdArc::StateId s, fst::StdArc::StateId backoff_state,
nlp_fst::StdArc::StateId s, nlp_fst::StdArc::StateId backoff_state,
const PpmStateCache& backoff_cache);

// Checks if lower order state caches have updated more recently.
bool LowerOrderCacheUpdated(fst::StdArc::StateId s) const;
bool LowerOrderCacheUpdated(nlp_fst::StdArc::StateId s) const;

// Ensures cache exists for state, creates it if not.
absl::StatusOr<PpmStateCache> EnsureCacheAtState(fst::StdArc::StateId s);
absl::StatusOr<PpmStateCache> EnsureCacheAtState(nlp_fst::StdArc::StateId s);

// Finds cache entry with oldest last access, for replacement.
int FindOldestLastAccessedCache() const;

// Establishes cache index, after performing garbage collection if needed.
absl::Status GetNewCacheIndex(fst::StdArc::StateId s);
absl::Status GetNewCacheIndex(nlp_fst::StdArc::StateId s);

// Adds new state to all required data structures and returns index.
absl::StatusOr<int> AddNewState(fst::StdArc::StateId backoff_dest_state);
absl::StatusOr<int> AddNewState(nlp_fst::StdArc::StateId backoff_dest_state);

// Returns origin state of arc with symbol from state s.
absl::StatusOr<int> GetArcOriginState(fst::StdArc::StateId s,
absl::StatusOr<int> GetArcOriginState(nlp_fst::StdArc::StateId s,
int sym_index);

// Returns destination state of arc with symbol from state s.
absl::StatusOr<int> GetDestinationState(fst::StdArc::StateId s,
absl::StatusOr<int> GetDestinationState(nlp_fst::StdArc::StateId s,
int sym_index);

// Returns probability of symbol leaving the current state.
absl::StatusOr<double> GetNegLogProb(fst::StdArc::StateId s,
absl::StatusOr<double> GetNegLogProb(nlp_fst::StdArc::StateId s,
int sym_index);

// Returns normalization value at the current state.
absl::StatusOr<double> GetNormalization(fst::StdArc::StateId s);
absl::StatusOr<double> GetNormalization(nlp_fst::StdArc::StateId s);

// Updates model at highest found state for given symbol.
absl::Status UpdateHighestFoundState(fst::StdArc::StateId curr_state,
absl::Status UpdateHighestFoundState(nlp_fst::StdArc::StateId curr_state,
int sym_index);

// Updates model at state where given symbol is not found.
absl::Status UpdateNotFoundState(fst::StdArc::StateId curr_state,
fst::StdArc::StateId highest_found_state,
fst::StdArc::StateId backoff_state,
absl::Status UpdateNotFoundState(nlp_fst::StdArc::StateId curr_state,
nlp_fst::StdArc::StateId highest_found_state,
nlp_fst::StdArc::StateId backoff_state,
int sym_index);

// Updates model with an observation of the sym_index at curr_state.
absl::StatusOr<fst::StdArc::StateId> UpdateModel(
fst::StdArc::StateId curr_state,
fst::StdArc::StateId highest_found_state, int sym_index);
absl::StatusOr<nlp_fst::StdArc::StateId> UpdateModel(
nlp_fst::StdArc::StateId curr_state,
nlp_fst::StdArc::StateId highest_found_state, int sym_index);

// Converts input string into linear FST at the character level, replacing
// characters not in possible_characters_ set (if non-empty) with kOovSymbol.
absl::StatusOr<fst::StdVectorFst> String2Fst(
absl::StatusOr<nlp_fst::StdVectorFst> String2Fst(
const std::string& input_string);

// Adds a single unigram count to every character.
Expand All @@ -319,10 +319,10 @@ class PpmAsFstModel : public LanguageModel {
double beta_; // Beta hyper-parameter for PPM.
bool static_model_; // Whether to use the model as static or dynamic.
std::vector<int> state_orders_; // Stores the order of each state.
std::unique_ptr<fst::StdVectorFst> fst_; // Model (counts) stored in FST.
std::unique_ptr<nlp_fst::StdVectorFst> fst_; // Model (counts) stored in FST.
// For counting character n-grams if training from text file.
std::unique_ptr<ngram::NGramCounter<fst::Log64Weight>> ngram_counter_;
std::unique_ptr<fst::SymbolTable> syms_; // Character symbols.
std::unique_ptr<ngram::NGramCounter<nlp_fst::Log64Weight>> ngram_counter_;
std::unique_ptr<nlp_fst::SymbolTable> syms_; // Character symbols.

// For caching probabilities and destination states for quick access.
int max_cache_size_; // Limit on caching for garbage collection.
Expand Down
12 changes: 6 additions & 6 deletions mozolm/models/ppm_as_fst_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ constexpr char kVocabFileName[] = "vocab.txt";

using ::nisaba::file::WriteTempTextFile;
using ::nisaba::utf8::DecodeSingleUnicodeChar;
using ::fst::ArcSort;
using ::fst::ILabelCompare;
using ::fst::Isomorphic;
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::SymbolTable;
using ::nlp_fst::ArcSort;
using ::nlp_fst::ILabelCompare;
using ::nlp_fst::Isomorphic;
using ::nlp_fst::StdArc;
using ::nlp_fst::StdVectorFst;
using ::nlp_fst::SymbolTable;
using ::testing::DoubleEq;
using ::testing::Each;

Expand Down
Loading

0 comments on commit 06bd53c

Please sign in to comment.