blob: cfa56df6928edad2c6ed2e266730f46b5b1153f4 [file] [edit]
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <memory>
#include "feature_extractor.h"
#include "feature_types.h"
#include "relevant_script_feature.h"
#include "script_detector.h"
#include "cld_3/protos/sentence.pb.h"
#include "sentence_features.h"
#include "task_context.h"
#include "utils.h"
#include "workspace.h"
namespace chrome_lang_id {
namespace relevant_script_feature_test {
namespace {

// Prints the one-line verdict (" Success for input: ..." or
// " Failure for input: ...") for `input` and returns `success` unchanged, so
// callers can finish with `return ReportOutcome(...)`.
bool ReportOutcome(const string &input, bool success) {
  std::cout << (success ? " Success for input: " : " Failure for input: ")
            << input << std::endl;
  return success;
}

// Checks whether the expected and actual float feature values are strictly
// less than 0.0001 apart.
bool FeatureValuesNear(float expected_value, float actual_value) {
  return std::abs(expected_value - actual_value) < 0.0001;
}

// Checks whether two sets of feature values match. Both maps must contain
// exactly the same feature ids, and every pair of values must satisfy
// FeatureValuesNear(). Prints a Success/Failure line for `test_input`. On
// failure it also prints one indented line describing the first mismatch, so a
// failing case can be diagnosed from the log alone.
bool FeaturesNear(const string &test_input,
                  const std::map<int, float> &expected_features,
                  const std::map<int, float> &actual_features) {
  if (expected_features.size() != actual_features.size()) {
    ReportOutcome(test_input, false);
    std::cout << "  expected " << expected_features.size()
              << " feature(s), got " << actual_features.size() << std::endl;
    return false;
  }
  // Equal sizes plus every expected id being present implies identical id
  // sets, so walking the expected side is sufficient.
  for (const auto &id_and_value : expected_features) {
    const int id = id_and_value.first;
    const float expected_value = id_and_value.second;
    // One lookup instead of count() followed by at().
    const auto actual = actual_features.find(id);
    if (actual == actual_features.end()) {
      ReportOutcome(test_input, false);
      std::cout << "  missing feature id " << id << std::endl;
      return false;
    }
    if (!FeatureValuesNear(expected_value, actual->second)) {
      ReportOutcome(test_input, false);
      std::cout << "  feature id " << id << ": expected " << expected_value
                << ", got " << actual->second << std::endl;
      return false;
    }
  }
  return ReportOutcome(test_input, true);
}

// Checks whether the set of features is empty, printing a Success/Failure line
// for `input`.
bool CheckFeaturesEmpty(const string &input,
                        const std::map<int, float> &actual_features) {
  return ReportOutcome(input, actual_features.empty());
}

}  // namespace
// Factory handed to the WholeSentenceFeature registrar: allocates a fresh
// RelevantScriptFeature on every call and returns it through its
// WholeSentenceFeature base. The caller receives ownership (presumably the
// registry / feature extractor frees it - confirm).
static WholeSentenceFeature *rsf_factory() {
  WholeSentenceFeature *const feature = new RelevantScriptFeature;
  return feature;
}
class RelevantScriptFeatureExtractor {
public:
RelevantScriptFeatureExtractor() {
if (WholeSentenceFeature::registry() == nullptr) {
// Create registry for our WholeSentenceFeature(s).
RegisterableClass<WholeSentenceFeature>::CreateRegistry(
"sentence feature function", "WholeSentenceFeature", __FILE__,
__LINE__);
}
// Register our WholeSentenceFeature(s).
// Register RelevantScriptFeature feature function.
static WholeSentenceFeature::Registry::Registrar rsf_registrar(
WholeSentenceFeature::registry(), "continuous-bag-of-relevant-scripts",
"RelevantScriptFeature", __FILE__, __LINE__, rsf_factory);
feature_extractor_.Parse("continuous-bag-of-relevant-scripts");
TaskContext context;
feature_extractor_.Setup(&context);
feature_extractor_.Init(&context);
feature_extractor_.RequestWorkspaces(&workspace_registry_);
}
// Returns "true" if feature extraction is successful, and "false" otherwise.
bool Extract(const string &text, std::map<int, float> *float_features) {
float_features->clear();
if (text.empty()) {
return true;
}
Sentence sentence;
sentence.set_text(text);
workspace_.Reset(workspace_registry_);
feature_extractor_.Preprocess(&workspace_, &sentence);
FeatureVector feature_vector;
feature_extractor_.ExtractFeatures(workspace_, sentence, &feature_vector);
for (int index = 0; index < feature_vector.size(); ++index) {
const FloatFeatureValue value =
FloatFeatureValue(feature_vector.value(index));
if (float_features->count(value.value.id) != 0) {
std::cout << " Failure: duplicate feature" << std::endl;
return false;
}
float_features->emplace(value.value.id, value.value.weight);
}
return true;
}
private:
WorkspaceSet workspace_;
WholeSentenceExtractor feature_extractor_;
// The registry of shared workspaces in the feature extractor.
WorkspaceRegistry workspace_registry_;
};
bool TestCommonCases() {
std::cout << "Running " << __FUNCTION__ << std::endl;
RelevantScriptFeatureExtractor extractor;
std::map<int, float> float_features;
bool test_successful = true;
string input = "just some plain text";
if (!extractor.Extract(input, &float_features) ||
!FeaturesNear(input, {{chrome_lang_id::kScriptOtherUtf8OneByte, 1.00}},
float_features)) {
test_successful = false;
}
input = "ヸヂ゠ヂ";
if (!extractor.Extract(input, &float_features) ||
!FeaturesNear(input, {{chrome_lang_id::kScriptKatakana, 1.00}},
float_features)) {
test_successful = false;
}
// 4 Latin letters mixed with 4 Katakana letters.
input = "ヸtヂe゠xtヂ";
if (!extractor.Extract(input, &float_features) ||
!FeaturesNear(input, {{chrome_lang_id::kScriptOtherUtf8OneByte, 0.5},
{chrome_lang_id::kScriptKatakana, 0.5}},
float_features)) {
test_successful = false;
}
input = "just some 121212%^^( ヸヂ゠ヂ text";
if (!extractor.Extract(input, &float_features) ||
!FeaturesNear(input, {{chrome_lang_id::kScriptOtherUtf8OneByte, 0.75},
{chrome_lang_id::kScriptKatakana, 0.25}},
float_features)) {
test_successful = false;
}
return test_successful;
}
// Exercises edge cases in two groups:
//   - Inputs with no script-bearing characters must produce no features.
//   - A single script-bearing UTF-8 character must get weight 1.0 for its
//     script, whether it stands alone or sits at either end or in the middle of
//     digits/punctuation.
// Every case runs even after a failure. Returns true only if all cases pass.
bool TestCornerCases() {
  std::cout << "Running " << __FUNCTION__ << std::endl;
  RelevantScriptFeatureExtractor extractor;
  std::map<int, float> float_features;
  bool all_passed = true;

  const char *const kFeaturelessInputs[] = {
      "",              // Empty string.
      " ",             // Only whitespaces.
      "12----)(",      // Only numbers and punctuation.
      "12--- - ) ( ",  // Only numbers, punctuation, and spaces.
  };
  for (const char *text : kFeaturelessInputs) {
    const string input = text;
    if (!extractor.Extract(input, &float_features) ||
        !CheckFeaturesEmpty(input, float_features)) {
      all_passed = false;
    }
  }

  const struct {
    const char *text;
    int script;
  } kSingleCharInputs[] = {
      // One UTF8 character by itself.
      {"ゟ", chrome_lang_id::kScriptHiragana},
      {"ה", chrome_lang_id::kScriptHebrew},
      // One UTF8 character with some numbers / punctuation / spaces: character
      // at one extremity or in the middle.
      {"1234ゟ", chrome_lang_id::kScriptHiragana},
      {"ゟ12-(", chrome_lang_id::kScriptHiragana},
      {"8*1ゟ12----", chrome_lang_id::kScriptHiragana},
  };
  for (const auto &test_case : kSingleCharInputs) {
    const string input = test_case.text;
    if (!extractor.Extract(input, &float_features) ||
        !FeaturesNear(input, {{test_case.script, 1.00}}, float_features)) {
      all_passed = false;
    }
  }
  return all_passed;
}
} // namespace relevant_script_feature_test
} // namespace chrome_lang_id
// Runs the feature extraction tests. Every test group runs even if an earlier
// group fails, so a single run reports all failures. Returns 0 when all groups
// pass and 1 otherwise.
int main(int argc, char **argv) {
  // Evaluate each group into its own variable. Chaining the calls with `&&`
  // would skip TestCornerCases() entirely whenever TestCommonCases() fails.
  const bool common_cases_ok =
      chrome_lang_id::relevant_script_feature_test::TestCommonCases();
  const bool corner_cases_ok =
      chrome_lang_id::relevant_script_feature_test::TestCornerCases();
  return (common_cases_ok && corner_cases_ok) ? 0 : 1;
}