import datetime
import json

import torch

# Special token ids used to build the input / reference pairs
ref_token_id = tokenizer.pad_token_id  # token used to build the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id  # separator between question and text; also appended after the text
cls_token_id = tokenizer.cls_token_id  # classification token placed at the start of the sequence

print(datetime.datetime.now())

def predict(inputs):
    """Run the QA model and return the start/end logits."""
    with torch.no_grad():
        output = model(inputs)
    return output.start_logits, output.end_logits

def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    """Build the input ids and a matching reference (baseline) sequence."""
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids: [CLS] question [SEP] text [SEP]
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]
    # construct reference token ids: same layout, with reference tokens in place of real tokens
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]
    return (torch.tensor([input_ids], device=device),
            torch.tensor([ref_input_ids], device=device),
            len(question_ids))

def predict_qt(question, text):
    """Predict the answer span for a (question, text) pair and return it as a string."""
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
        question, text, ref_token_id, sep_token_id, cls_token_id)
    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    start_scores, end_scores = predict(input_ids)
    # the answer is the token span between the most likely start and end positions
    return ' '.join(all_tokens[torch.argmax(start_scores): torch.argmax(end_scores) + 1])

def test_validation():
    """Predict answers for the validation set and dump them to a timestamped JSON file."""
    eval_dict = {}
    for i in range(len(validation_contexts)):
        question = validation_question[i]
        text = validation_contexts[i]
        # rough length filter (character count) to stay under the model's 512-token limit
        if len(text) <= 512:
            answer = predict_qt(question, text)
            eval_dict[str(validation_ids[i])] = answer
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    filename = timestamp + "answers.txt"
    with open("./results/" + filename, "w", encoding='utf-8') as fo:
        fo.write(json.dumps(eval_dict))
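
The snippet above assumes that model, tokenizer, device, and the lists validation_contexts, validation_question, and validation_ids are already defined elsewhere. A minimal sketch of that setup, assuming a Hugging Face transformers BERT checkpoint fine-tuned on SQuAD and a SQuAD-style validation file (the checkpoint name, file path, and JSON layout are illustrative assumptions, not part of the original), might look like this:

# Hypothetical setup for the names used above (model, tokenizer, device,
# validation_contexts, validation_question, validation_ids).
import json
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Any BERT checkpoint fine-tuned for extractive QA would work here.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name).to(device)
model.eval()

# Flatten a SQuAD-style validation file into parallel lists of
# contexts, questions, and question ids.
validation_contexts, validation_question, validation_ids = [], [], []
with open("dev-v2.0.json", encoding="utf-8") as f:  # path is an assumption
    squad = json.load(f)
for article in squad["data"]:
    for paragraph in article["paragraphs"]:
        for qa in paragraph["qas"]:
            validation_contexts.append(paragraph["context"])
            validation_question.append(qa["question"])
            validation_ids.append(qa["id"])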
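
For a quick sanity check before running the full loop, predict_qt can be called on a single hand-made pair (the question and passage below are made-up examples):

print(predict_qt("Who wrote Hamlet?",
                 "Hamlet is a tragedy written by William Shakespeare sometime between 1599 and 1601."))

test_validation() then runs the same prediction over every validation example whose context is short enough and writes the answers to ./results/<timestamp>answers.txt; the results directory must already exist.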