File size: 8,292 Bytes
db56acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from tokenizers import Tokenizer
import torch
import numpy as np
import time
import os
from datetime import datetime


def process_string_into_pairs(input_str: str) -> list[str]:
    """Split *input_str* into tokens, pairing adjacent lowercase letters.

    Scanning left to right:
    - two consecutive lowercase ASCII letters become one two-letter token;
    - a lone lowercase letter followed by a space becomes a one-letter
      token and the space is consumed;
    - every other character is emitted unchanged as its own token.
    """
    tokens: list[str] = []
    pos = 0
    length = len(input_str)

    while pos < length:
        current = input_str[pos]

        # Non-lowercase characters pass through one at a time.
        if not ("a" <= current <= "z"):
            tokens.append(current)
            pos += 1
            continue

        nxt = input_str[pos + 1] if pos + 1 < length else ""
        if "a" <= nxt <= "z":
            # Two lowercase letters in a row form a single pair token.
            tokens.append(current + nxt)
            pos += 2
        elif nxt == " ":
            # Lone letter followed by a space: keep the letter, drop the space.
            tokens.append(current)
            pos += 2
        else:
            # Lone letter at end of string or before a non-space character.
            tokens.append(current)
            pos += 1

    return tokens


def get_mask_from_string(input_str: str, tokenizer) -> torch.Tensor:
    """Encode *input_str* as a 1-D long tensor of token ids.

    Each segment produced by ``process_string_into_pairs`` that is pure
    ASCII is wrapped into a ``<|mask_...|>`` placeholder token; segments
    containing any non-ASCII character are looked up verbatim.
    """
    token_ids = []
    for segment in process_string_into_pairs(input_str):
        is_ascii = all(ord(ch) < 128 for ch in segment)
        lookup_key = f"<|mask_{segment}|>" if is_ascii else segment
        token_ids.append(tokenizer.token_to_id(lookup_key))
    return torch.tensor(token_ids, dtype=torch.long)


def inference(model, input_str: str, tokenizer, device, threshold: float = 0.9) -> None:
    """Iteratively demask *input_str* with *model*, printing each round.

    The string is converted into mask tokens, then repeatedly fed through
    the model. Every masked position whose top prediction exceeds
    *threshold* is unmasked in that round; if none clears the threshold,
    the single most confident position is unmasked instead, so the loop
    always makes progress and terminates. The partially-demasked string
    is printed after every round.
    """
    model.eval()
    
    # Initialize NgramHashMapping
    # NOTE(review): hash_mapping is constructed below but never used in
    # this function — looks like leftover or work-in-progress wiring;
    # confirm whether the model consumes it implicitly before removing.
    engram_cfg = model.config.engram_config
    hash_mapping = None
    if engram_cfg is not None:
        from modeling_llada_engram import ModelConfig, EngramConfig, NgramHashMapping
        from dataclasses import fields
        # Prepare ModelConfig for NgramHashMapping
        backbone_config_dict = model.config.to_dict()
        # Filter out keys not in ModelConfig if necessary, but ModelConfig usually matches LLaDAConfig fields
        backbone_config = ModelConfig(**{k: v for k, v in backbone_config_dict.items() if k in [f.name for f in fields(ModelConfig)]})
        
        hash_mapping = NgramHashMapping(
            engram_vocab_size = engram_cfg.get('engram_vocab_size', [129280*5, 129280*5]),
            max_ngram_size    = engram_cfg.get('max_ngram_size', 3),
            n_embed_per_ngram = engram_cfg.get('n_embed_per_ngram', 512),
            n_head_per_ngram  = engram_cfg.get('n_head_per_ngram', 8),
            layer_ids         = engram_cfg.get('layer_ids', [1, 15]),
            pad_id            = engram_cfg.get('pad_id', 2),
            seed              = engram_cfg.get('seed', 0),
            config            = backbone_config,
        )

    with torch.no_grad():
        # Shape (1, seq_len): batch of one sequence of mask-token ids.
        mask_tensor = get_mask_from_string(input_str, tokenizer).unsqueeze(0).to(device)
        # is_masked = torch.ones(mask_tensor.shape, dtype=torch.bool, device=device)
        # Positions holding an id >= the base "<|mask|>" id count as masked.
        # assumes all <|mask_*|> ids sit contiguously at/above that id — TODO confirm
        is_masked = mask_tensor >= tokenizer.token_to_id("<|mask|>")
        rounds = 0
        while is_masked.any():
            rounds += 1

            output = model(input_ids=mask_tensor)[0]
            # Logit to probability
            output = torch.softmax(output, dim=-1)
            unmasked_any = False
            prob_info = []

            most_certain_token = (0, 0, 0) # (probability, index, token_id)
            # Check each token that still is_masked
            for i in range(mask_tensor.shape[1]):
                if is_masked[0, i]:
                    # Get the token with the highest probability
                    predicted_token = output[0, i].argmax().item()
                    prob_info.append(
                        f"{output[0, i, predicted_token].item():.2f} {tokenizer.id_to_token(predicted_token)}"
                    )
                    # Tuples compare element-wise, so max() picks the
                    # position with the highest probability first.
                    most_certain_token = max(
                        most_certain_token,
                        (output[0, i, predicted_token].item(), i, predicted_token)
                    )
                    # If the probability is above the threshold, replace the mask
                    if output[0, i, predicted_token].item() > threshold:
                        mask_tensor[0, i] = predicted_token
                        is_masked[0, i] = False
                        unmasked_any = True
                else:
                    # Already-resolved position: keep prob_info aligned by index.
                    prob_info.append("")
            if not unmasked_any:
                # Unmask the most certain one
                mask_tensor[0, most_certain_token[1]] = most_certain_token[2]
                is_masked[0, most_certain_token[1]] = False

            # Render the current state: resolved positions show their token
            # text; still-masked positions show the raw pair by stripping
            # the "<|mask_" prefix (7 chars) and "|>" suffix (2 chars).
            masked_str = "".join(
                (
                    tokenizer.id_to_token(mask_tensor[0, i].item())
                    if not is_masked[0, i]
                    else tokenizer.id_to_token(mask_tensor[0, i].item())[7:-2]
                )
                for i in range(mask_tensor.shape[1])
            )
            print(masked_str)


if __name__ == "__main__":
    # Prefer GPU when available; otherwise run on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Load the model from the current directory through AutoModel.
    # trust_remote_code=True is required for the custom architecture code.
    try:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True).to(device)
    except Exception as e:
        print(f"Failed to load with AutoModel: {e}")
        print("Falling back to manual loading (if needed, but prefer AutoModel for validation)")
        # No manual fallback is provided: AutoModel is the supported path.
        raise e

    # Use bfloat16 on GPU, full float32 on CPU.
    if device.type == "cuda":
        model = model.to(torch.bfloat16)
    else:
        model = model.float()
    param_count = sum(p.numel() for p in model.parameters())
    print("Loaded model. Parameters:", param_count)

    threshold = 0.9

    # Interactive loop: demask each line the user types, forever.
    while True:
        input_str = input("Enter a string to process: ")
        inference(model, input_str, tokenizer, device, threshold=threshold)
        print("")  # blank line between runs
# Input example: nhkzotdgjvdmleunkmiekz。
# Output: 黄河是中华民族的母亲河。

# Input example: mdflswsyelfl,eyxxmdswsyelfl,raxxmdelfl,otfixdzhfnjrugfoirmbisunswsyelfl。zhldxxdgun“mdfl”uvelflqhnvxtmdunkmpbofvjcjnnmdunsoirpbucheel。
# Output: 大型语言模型,也称大语言模型,简称大模型,是一种基于人工神经网络的语言模型。其名称中的“大型”指模型具有庞大的参数量以及巨大的训练数据规模。

# Input example: hgzz(Go o g l e )otfiwjpmrnxjuchkaf,hdidjifngmrnsdoovsoggn.
# Output:
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城.
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。
# 谷歌(Google)是一家跨国科技公司,总部位于美国加州山景城。

# Input example: jxvuygvbotghtusvwtvbdt。auwvvbotcbghwhtkshdl?
# Output:
# 天对地,雨对风。大陆对长空。山lj对ke树,赤日对ljeb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨tq晚霞红。
# 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍eb。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
# 天对地,雨对风。大陆对长空。山lj对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
# 天对地,雨对风。大陆对长空。山苍对杂树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨雷晚霞红。
# (Expected Output: 天对地,雨对风。大陆对长空。山花对海树,赤日对苍穹。雷隐隐,雾蒙蒙。日下对天中。风高秋月白,雨霁晚霞红。)