Multi-label text classification with Hugging Face Transformers
2022-06-26 08:26:00 【xuanningmeng】
Multi-label classification
Text classification is one of the basic tasks of natural language processing. Most text classification is multi-class: the data has several categories, and each sample belongs to exactly one of them. In real work and projects you also meet multi-label text, where one sample can carry several labels at once. I use Hugging Face Transformers to implement multi-label text classification. My TensorFlow version is 2.4.0 and my transformers version is 4.2.0.
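To make the distinction concrete, the sketch below encodes label sets as multi-hot vectors, the target format that a sigmoid output layer trained with binary cross-entropy expects. The label names and the helper `to_multi_hot` are made up for illustration and are not taken from the original project.

```python
# Hypothetical label vocabulary; replace it with the labels of your own dataset.
LABELS = ["sports", "finance", "technology", "entertainment"]
LABEL_TO_INDEX = {name: i for i, name in enumerate(LABELS)}


def to_multi_hot(sample_labels, label_to_index=LABEL_TO_INDEX):
    """Turn a collection of label names into a 0/1 vector with one slot per label."""
    vector = [0] * len(label_to_index)
    for name in sample_labels:
        vector[label_to_index[name]] = 1
    return vector


# Multi-class: exactly one label per sample, so exactly one slot is 1.
print(to_multi_hot(["finance"]))                # [0, 1, 0, 0]
# Multi-label: a sample may carry several labels at once.
print(to_multi_hot(["finance", "technology"]))  # [0, 1, 1, 0]
```

A list of such vectors is what would be passed as `labels` to the data-processing function in the next section.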
Data processing
The BertTokenizer from transformers tokenizes the data. The code is as follows:
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import BertTokenizer


def get_model_data(data, labels, max_seq_len=128):
    """Tokenize sentences and return ([input_ids, attention_mask], labels) as arrays."""
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese", do_lower_case=True)
    dataset_dict = {
        "input_ids": [],
        "attention_mask": [],
        "label": []
    }
    assert len(data) == len(labels)
    for i in range(len(data)):
        sentence = data[i]
        input_ids = tokenizer.encode(
            sentence,                 # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'.
            max_length=max_seq_len,   # Upper bound on the sequence length.
            truncation=True,          # Truncate explicitly instead of relying on the default.
        )
        sentence_length = len(input_ids)
        # Pad on the right with 0 up to max_seq_len; int32 matches the model inputs.
        input_ids = pad_sequences([input_ids],
                                  maxlen=max_seq_len,
                                  dtype="int32",
                                  value=0,
                                  truncating="post",
                                  padding="post")
        input_ids = input_ids.tolist()[0]
        # 1 marks real tokens, 0 marks padding.
        attention_mask = [1] * sentence_length + [0] * (max_seq_len - sentence_length)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        dataset_dict["label"].append(labels[i])
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["label"]
    return x, y
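The padding and masking logic inside `get_model_data` can be checked in isolation, without downloading a tokenizer. The following sketch reproduces it with plain Python lists. `pad_and_mask` is a hypothetical helper, and the token ids are placeholders rather than real tokenizer output.

```python
def pad_and_mask(token_ids, max_seq_len, pad_id=0):
    """Truncate or right-pad token ids to max_seq_len and build the matching mask."""
    # Naive "post" truncation. A real tokenizer truncates before adding special
    # tokens, so it would still end the sequence with '[SEP]'.
    token_ids = list(token_ids)[:max_seq_len]
    real_len = len(token_ids)
    padding = [pad_id] * (max_seq_len - real_len)
    input_ids = token_ids + padding                      # "post" padding
    attention_mask = [1] * real_len + [0] * len(padding)  # 1 = token, 0 = padding
    return input_ids, attention_mask


ids, mask = pad_and_mask([101, 11, 12, 13, 102], max_seq_len=8)
print(ids)   # [101, 11, 12, 13, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Both lists always have length `max_seq_len`, which is what lets the per-sentence results be stacked into rectangular NumPy arrays.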
Multi-label classification model
A multi-label classification model is built with Transformers. In multi-label classification the activation function of the last layer is sigmoid, whereas in multi-class classification it is softmax. The loss function for multi-label classification is BinaryCrossentropy. The code is as follows:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from transformers import TFBertModel


class BertMultiClassifier(object):
    def __init__(self, bert_model_name, label_num):
        self.label_num = label_num
        self.bert_model_name = bert_model_name

    def get_model(self):
        bert = TFBertModel.from_pretrained(self.bert_model_name)
        input_ids = Input(shape=(None,), dtype=tf.int32, name="input_ids")
        attention_mask = Input(shape=(None,), dtype=tf.int32, name="attention_mask")
        # Index 1 is the pooled '[CLS]' representation.
        outputs = bert(input_ids, attention_mask=attention_mask)[1]
        # One independent sigmoid unit per label.
        cla_outputs = Dense(self.label_num, activation='sigmoid')(outputs)
        model = Model(
            inputs=[input_ids, attention_mask],
            outputs=[cla_outputs])
        return model


def create_model(bert_model_name, label_nums):
    model = BertMultiClassifier(bert_model_name, label_nums).get_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    # from_logits=False because the last layer already applies sigmoid.
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)
    model.compile(optimizer=optimizer, loss=loss_object,
                  metrics=['accuracy', tf.keras.metrics.Precision(),
                           tf.keras.metrics.Recall(),
                           tf.keras.metrics.AUC()])
    return model
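Why sigmoid rather than softmax matters here can be seen with a few numbers. The sketch below is a plain-Python illustration, not the Keras implementation: it applies both activations to the same made-up scores and computes the binary cross-entropy for one sample, averaged over the labels.

```python
import math


def sigmoid(logits):
    """Score each label independently; the outputs need not sum to 1."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def softmax(logits):
    """Make the labels compete; the outputs always sum to 1."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean of the per-label binary cross-entropy terms for one sample."""
    losses = []
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        losses.append(-(t * math.log(p) + (1 - t) * math.log(1.0 - p)))
    return sum(losses) / len(losses)


logits = [2.0, 1.5, -3.0]  # made-up scores for three labels
target = [1, 1, 0]         # the sample carries the first two labels

print([round(p, 3) for p in sigmoid(logits)])  # two labels can both be near 1
print([round(p, 3) for p in softmax(logits)])  # probability mass is shared
print(round(binary_cross_entropy(target, sigmoid(logits)), 4))
```

With sigmoid, the first two labels can both score above 0.5 at the same time. With softmax they have to split a total of 1, so it cannot express "both labels apply" with high confidence.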
Model training
Training uses Keras, the high-level API of TensorFlow. The weights are saved in h5 format, and the model is also saved as a pb model (SavedModel). The training code is as follows:
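import os

# model comes from create_model; args and callbacks are defined elsewhere in the project.
model.fit(train_x, train_y, epochs=args["epoch"], verbose=1,
          batch_size=args["batch_size"],
          callbacks=callbacks,
          validation_data=(val_x, val_y),
          validation_batch_size=args["batch_size"])
# Save the weights only, in HDF5 format.
model_path = os.path.join("./output/model/", "mulclassifition.h5")
model.save_weights(model_path)
# Export the full model in SavedModel (pb) format for TensorFlow Serving.
tf.keras.models.save_model(model, args["pbmodel_path"], save_format="tf", overwrite=True)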
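One practical detail when exporting for TensorFlow Serving: the server looks for numbered version subdirectories (for example `.../multiclass/1/`) under the model's base path, so `args["pbmodel_path"]` would normally end in such a number. The sketch below uses a hypothetical helper, `next_version_dir`, to pick the next free version number. The directory layout is an assumption about how the export might be organized, not something taken from the original project.

```python
import os
import tempfile


def next_version_dir(model_base_dir):
    """Return model_base_dir/<n>, where n is one more than the highest numeric subdirectory."""
    os.makedirs(model_base_dir, exist_ok=True)
    versions = [
        int(name)
        for name in os.listdir(model_base_dir)
        if name.isdigit() and os.path.isdir(os.path.join(model_base_dir, name))
    ]
    return os.path.join(model_base_dir, str(max(versions, default=0) + 1))


# Demonstration in a throwaway directory, so nothing is written to the project.
with tempfile.TemporaryDirectory() as base:
    print(os.path.basename(next_version_dir(base)))  # 1
    os.makedirs(os.path.join(base, "1"))
    print(os.path.basename(next_version_dir(base)))  # 2
```

The returned path could then be passed to `tf.keras.models.save_model` in place of a fixed directory.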
Loading the model for prediction
A trained model can be loaded directly for prediction, or deployed with TensorFlow Serving to provide an HTTP service. I introduce both methods. The code that loads the model directly for prediction is as follows:
def predict(test_data, args, label_num):
    # get_model_data asserts len(data) == len(labels), so pass placeholder labels.
    dummy_labels = [[0] * label_num for _ in test_data]
    testdata, _ = get_model_data(test_data, dummy_labels, args["max_length"])
    print("testdata: ", testdata)
    model = create_model(args['bert_model_name'], label_num)
    model.load_weights("./output/model/mulclassifition.h5")
    # The model ends in sigmoid, so these are per-label probabilities.
    pred_probs = model.predict(testdata, batch_size=args["batch_size"])
    # A label is predicted when its probability is at least 0.5.
    pred = np.where(pred_probs >= 0.5, 1, 0).tolist()
    return pred
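The 0/1 vectors returned by `predict` still have to be mapped back to label names. A minimal sketch of that step follows, working directly on probabilities. `decode_predictions`, the label names, and the probability values are all hypothetical.

```python
def decode_predictions(probabilities, label_names, threshold=0.5):
    """Return, per sample, the names of the labels whose probability reaches threshold."""
    decoded = []
    for row in probabilities:
        decoded.append([name for name, p in zip(label_names, row) if p >= threshold])
    return decoded


label_names = ["sports", "finance", "technology"]  # hypothetical labels
probs = [[0.91, 0.08, 0.67],   # two labels pass the threshold
         [0.12, 0.30, 0.05]]   # no label passes the threshold
print(decode_predictions(probs, label_names))
# [['sports', 'technology'], []]
```

With a fixed threshold a sample can end up with no label at all, as in the second row. Whether that is acceptable, or whether the top-scoring label should be kept as a fallback, depends on the application.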
HTTP service
TensorFlow Serving and Flask together provide the HTTP service. The code is as follows:
import json

import numpy as np
import requests
from flask import Flask, jsonify, request

app = Flask(__name__)


@app.route("/multiclassfier", methods=['POST'])
def multiclassifier_pred():
    data_para = json.loads(request.get_data(as_text=True))
    # Assumption: "sent" holds a single sentence string.
    sentence = data_para["sent"]
    print("sentence: ", sentence)
    # get model input; get_model_data expects lists, so wrap the sentence
    # and pass a placeholder label
    test_x, _ = get_model_data([sentence], [0], 256)
    # NumPy arrays are not JSON serializable, so convert them to lists
    input_ids = test_x[0].tolist()
    attention_mask = test_x[1].tolist()
    data = json.dumps({
        "signature_name": "serving_default",
        "inputs": {
            "input_ids": input_ids,
            "attention_mask": attention_mask}})
    headers = {
        "content-type": "application/json"}
    result = requests.post("http://ip:port/v1/models/multiclass:predict", data=data, headers=headers)
    if result.status_code == 200:
        result = json.loads(result.text)
        pred_probs = np.array(result["outputs"])
        pred = np.where(pred_probs >= 0.5, 1, 0).tolist()
        # label_encoder and label are the author's helper and label list, not shown here
        pred_encoder = label_encoder(pred, label)
        return_result = {
            "code": 200, "sent": sentence, "label": pred_encoder[0]}
        return jsonify(return_result)
    else:
        # No exception is active here, so report the response body instead of a traceback
        return jsonify({
            "code": result.status_code,
            "message": result.text})
In the code, http://ip:port/v1/models/multiclass:predict is the prediction endpoint that TensorFlow Serving exposes after loading the model. TensorFlow Serving itself is deployed with Docker.
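The request and response bodies exchanged with the prediction endpoint can be exercised without a running server. The sketch below separates the two JSON steps of the Flask route into hypothetical helpers, `build_predict_payload` and `parse_predict_response`. The token ids and the response are hand-written stand-ins that follow the same shape as the code above.

```python
import json


def build_predict_payload(input_ids, attention_mask, signature_name="serving_default"):
    """Serialize model inputs into the JSON body sent to the predict endpoint."""
    return json.dumps({
        "signature_name": signature_name,
        "inputs": {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        },
    })


def parse_predict_response(response_text, threshold=0.5):
    """Read the "outputs" field of a predict response and threshold it to 0/1."""
    outputs = json.loads(response_text)["outputs"]
    return [[1 if p >= threshold else 0 for p in row] for row in outputs]


# Placeholder ids for one sentence padded to length 5 (batch of size 1).
payload = build_predict_payload([[101, 11, 12, 102, 0]], [[1, 1, 1, 1, 0]])
print(payload)

# A hand-written response in the same shape, standing in for a real server reply.
fake_response = json.dumps({"outputs": [[0.83, 0.02, 0.57]]})
print(parse_predict_response(fake_response))  # [[1, 0, 1]]
```

Both inputs are nested lists with one inner list per sentence, and the keys under "inputs" have to match the names of the model's `Input` layers, `input_ids` and `attention_mask`.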