Supervised learning
M. Benesty
2017-09-18
library(fastrtext)
data("train_sentences")
data("test_sentences")

# Prepare data: each example becomes one line of the form
# "__label__<class> <lower-cased text>", which is the layout fastText's
# supervised mode reads.
train_labels <- paste0("__label__", train_sentences[, "class.text"])
train_texts <- tolower(train_sentences[, "text"])
train_to_write <- paste(train_labels, train_texts)

test_labels <- paste0("__label__", test_sentences[, "class.text"])
test_texts <- tolower(test_sentences[, "text"])
test_to_write <- paste(test_labels, test_texts)

# Temporary locations for the trained model and the training corpus.
tmp_file_model <- tempfile()
train_tmp_file_txt <- tempfile()
writeLines(text = train_to_write, con = train_tmp_file_txt)

# Learn model (one command-line flag and its value per line).
execute(commands = c(
  "supervised",
  "-input", train_tmp_file_txt,
  "-output", tmp_file_model,
  "-dim", 20,
  "-lr", 1,
  "-epoch", 20,
  "-wordNgrams", 2,
  "-verbose", 1
))
##
## Read 0M words
## Number of words: 5060
## Number of labels: 15
##
## Progress: 100.0% words/sec/thread: 1474675 lr: 0.000000 loss: 0.321715 eta: 0h0m
# Load the trained model back from disk.
model <- load_model(tmp_file_model)
## add .bin extension to the path
# predict() returns a list with one element per input sentence: a numeric
# vector of probabilities whose names are the predicted labels.
predictions <- predict(model, sentences = test_to_write)
print(head(predictions, n = 5))
## [[1]]
## __label__OWNX
## 0.9980469
##
## [[2]]
## __label__MISC
## 0.9785156
##
## [[3]]
## __label__MISC
## 0.9902344
##
## [[4]]
## __label__OWNX
## 0.9023438
##
## [[5]]
## __label__AIMX
## 0.9863281
# Compute accuracy: the share of test sentences whose predicted label equals
# the gold label. vapply() (unlike sapply()) always yields a character vector
# and errors loudly if a prediction does not carry exactly one named label,
# instead of silently returning a list.
mean(vapply(predictions, names, character(1)) == test_labels)
## [1] 0.8316667
# Because there is only one category per observation, the Hamming loss
# reported below is the same value as the accuracy above.
get_hamming_loss(as.list(test_labels), predictions)
## [1] 0.8316667
# Run prediction on the test sentences a second time; the printed results
# below are identical to those of the first call.
predictions <- predict(model, sentences = test_to_write)
print(head(predictions, n = 5))
## [[1]]
## __label__OWNX
## 0.9980469
##
## [[2]]
## __label__MISC
## 0.9785156
##
## [[3]]
## __label__MISC
## 0.9902344
##
## [[4]]
## __label__OWNX
## 0.9023438
##
## [[5]]
## __label__AIMX
## 0.9863281
# Free memory and remove temporary files.
# load_model() reported adding a ".bin" extension to the model path, so the
# model actually lives at "<tmp_file_model>.bin" and unlink(tmp_file_model)
# alone would leave it behind. fastText may also write a ".vec" file next to
# it; unlink() does not fail on paths that do not exist, so listing all
# candidates is safe.
unlink(train_tmp_file_txt)
unlink(paste0(tmp_file_model, c("", ".bin", ".vec")))
rm(model)
gc()
## used (Mb) gc trigger (Mb) max used (Mb)
## Ncells 551597 29.5 940480 50.3 940480 50.3
## Vcells 1141478 8.8 1946338 14.9 1554547 11.9