Spaces:
Build error
Build error
// Replaces every non-overlapping occurrence of `needle` in `str` with
// `replacement`, scanning left to right. Text inserted by a replacement is not
// rescanned, so a replacement that itself contains `needle` does not expand
// repeatedly.
// An empty `needle` is a no-op: it would match at the same position forever.
void utreplace(std::string & str, const std::string & needle, const std::string & replacement) {
    if (needle.empty()) {
        return;
    }
    size_t pos = 0;
    while ((pos = str.find(needle, pos)) != std::string::npos) {
        str.replace(pos, needle.length(), replacement);
        pos += replacement.length(); // continue after the inserted text
    }
}
std::map<std::string, int32_t> json_parse(const std::string & fname) { | |
std::map<std::string, int32_t> result; | |
// read file into string | |
std::string json; | |
{ | |
std::ifstream ifs(fname); | |
if (!ifs) { | |
fprintf(stderr, "Failed to open %s\n", fname.c_str()); | |
exit(1); | |
} | |
json = std::string((std::istreambuf_iterator<char>(ifs)), | |
(std::istreambuf_iterator<char>())); | |
} | |
if (json[0] != '{') { | |
return result; | |
} | |
// parse json | |
{ | |
bool has_key = false; | |
bool in_token = false; | |
std::string str_key = ""; | |
std::string str_val = ""; | |
int n = json.size(); | |
for (int i = 1; i < n; ++i) { | |
if (!in_token) { | |
if (json[i] == ' ') continue; | |
if (json[i] == '"') { | |
in_token = true; | |
continue; | |
} | |
} else { | |
if (json[i] == '\\' && i+1 < n) { | |
if (has_key == false) { | |
str_key += json[i]; | |
} else { | |
str_val += json[i]; | |
} | |
++i; | |
} else if (json[i] == '"') { | |
if (has_key == false) { | |
has_key = true; | |
++i; | |
while (json[i] == ' ') ++i; | |
++i; // : | |
while (json[i] == ' ') ++i; | |
if (json[i] != '\"') { | |
while (json[i] != ',' && json[i] != '}') { | |
str_val += json[i++]; | |
} | |
has_key = false; | |
} else { | |
in_token = true; | |
continue; | |
} | |
} else { | |
has_key = false; | |
} | |
::utreplace(str_key, "\\u0120", " " ); // \u0120 -> space | |
::utreplace(str_key, "\\u010a", "\n"); // \u010a -> new line | |
::utreplace(str_key, "\\\"", "\""); // \\\" -> " | |
try { | |
result[str_key] = std::stoi(str_val); | |
} catch (...) { | |
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str()); | |
} | |
str_key = ""; | |
str_val = ""; | |
in_token = false; | |
continue; | |
} | |
if (has_key == false) { | |
str_key += json[i]; | |
} else { | |
str_val += json[i]; | |
} | |
} | |
} | |
} | |
return result; | |
} | |
void gpt_vocab::add_special_token(const std::string & token) { | |
special_tokens.push_back(token); | |
} | |
// Encodes a wide string as UTF-8 using std::wstring_convert with
// std::codecvt_utf8<wchar_t>. Because the converter is default-constructed,
// with no error string, to_bytes throws std::range_error on a conversion failure.
// NOTE: std::wstring_convert/std::codecvt_utf8 are deprecated since C++17.
// With a 16-bit wchar_t, codecvt_utf8 handles UCS-2 only (no surrogate pairs).
std::string convert_to_utf8(const std::wstring & input) {
    using utf8_converter = std::wstring_convert<std::codecvt_utf8<wchar_t>>;
    return utf8_converter{}.to_bytes(input);
}
// Decodes UTF-8 `input` into a wide string using std::wstring_convert with
// std::codecvt_utf8<wchar_t>. Because the converter is default-constructed,
// with no error string, from_bytes throws std::range_error on malformed UTF-8.
// NOTE: std::wstring_convert/std::codecvt_utf8 are deprecated since C++17.
// With a 16-bit wchar_t, codecvt_utf8 handles UCS-2 only (no surrogate pairs).
std::wstring convert_to_wstring(const std::string & input) {
    using utf8_converter = std::wstring_convert<std::codecvt_utf8<wchar_t>>;
    return utf8_converter{}.from_bytes(input);
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) { | |
std::vector<std::string> words; | |
// first split the text into words | |
{ | |
std::string str = text; | |
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; | |
// Generate the subpattern from the special_tokens vector if it's not empty | |
if (!vocab.special_tokens.empty()) { | |
std::string special_tokens_subpattern; | |
for (const auto & token : vocab.special_tokens) { | |
if (!special_tokens_subpattern.empty()) { | |
special_tokens_subpattern += "|"; | |
} | |
special_tokens_subpattern += token; | |
} | |
// Modify the regex pattern with the generated special tokens subpattern | |
pat = special_tokens_subpattern + "|" + pat; | |
} | |
std::regex re(pat); | |
std::smatch m; | |
while (std::regex_search(str, m, re)) { | |
for (auto x : m) { | |
words.push_back(x); | |
} | |
str = m.suffix(); | |
} | |
} | |
// find the longest token that forms each word in words: | |
std::vector<gpt_vocab::id> tokens; | |
for (const auto & word : words) { | |
for (int i = 0; i < word.size(); ){ | |
for (int j = word.size() - 1; j >= i; j--){ | |
auto cand = word.substr(i, j-i+1); | |
auto it = vocab.token_to_id.find(cand); | |
if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab | |
tokens.push_back(it->second); | |
i = j + 1; | |
break; | |
} | |
else if (j == i){ // word.substr(i, 1) has no matching | |
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data()); | |
i++; | |
} | |
} | |
} | |
} | |
return tokens; | |
} | |
// Returns true if `name` contains any of the tensor-name fragments listed below
// (GPT-J style ".attn.q_proj.weight", GPT-2 style "/attn/c_attn/w", ...).
// Presumably callers transpose such tensors when converting; confirm at the call site.
bool should_transpose_layer(std::string name)
{
    static const char * const transposed_fragments[] = {
        ".mlp.fc_in.weight",
        ".attn.out_proj.weight",
        ".attn.q_proj.weight",
        ".attn.k_proj.weight",
        ".attn.v_proj.weight",
        "/attn/c_attn/w",
        "/attn/c_proj/w",
        "/mlp/c_fc/w",
        "/mlp/c_proj/w",
    };

    for (const char * fragment : transposed_fragments) {
        if (name.find(fragment) != std::string::npos) {
            return true;
        }
    }
    return false;
}