{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jupyter Notebook for Interactive Labeling\n",
    "______\n",
    "\n",
    "This Jupyter Notebook combines manual and automated labeling.\n",
    "It uses scikit-learn's Label Propagation algorithm.\n",
    "Based on the estimated class probabilities, we decide whether a news article has to be labeled manually or can be labeled automatically.\n",
    "For multiclass labeling, 3 classes (0, 1 and 2) are used.\n",
    "\n",
    "In each iteration we...\n",
    "- check/correct the next 100 article labels manually.\n",
    " \n",
    "- apply the Label Propagation classification algorithm, which returns a vector class_probs $(K_0, K_1, K_2)$ per sample with the estimated probability $K_i$ of each class $i$. An estimated class label is adopted automatically if its estimated probability $K_x > 0.99$ with $x \\in \\{0, 1, 2\\}$.\n",
    " \n",
    "Please note: User instructions are written in upper case.\n",
    "__________\n",
    "Version: 2019-02-04, Anne Lorenz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "import operator\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "from ipywidgets import interact, interactive, fixed, interact_manual\n",
    "import ipywidgets as widgets\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "from IPython.display import display\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from LabelPropagation import LabelPropagation\n",
    "from MNBInteractive import MNBInteractive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part I: Data preparation\n",
    "\n",
    "First, we import our data set of 10 000 business news articles from a CSV file.\n",
    "It contains 833 or 834 articles for each month of 2017.\n",
    "For detailed information regarding the data set, please read the full documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# round number to save intermediate label status of data set\n",
    "m = -1\n",
    "\n",
    "# initialize random => reproducible sequence\n",
    "random.seed(5)\n",
    "\n",
    "filepath = '../data/cleaned_data_set_without_header.csv'\n",
    "\n",
    "# set up wider display area\n",
    "pd.set_option('display.max_colwidth', -1)\n",
    "\n",
    "# set precision of output\n",
    "np.set_printoptions(precision=3)\n",
    "\n",
    "# display the result of every top-level expression in a cell, not only the last one\n",
    "InteractiveShell.ast_node_interactivity = \"all\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in data set in total: 10000\n"
     ]
    }
   ],
   "source": [
    "df = pd.read_csv(filepath,\n",
    "                 header=None,\n",
    "                 sep='|',\n",
    "                 engine='python',\n",
    "                 names=[\"Uuid\", \"Title\", \"Text\", \"Site\", \"SiteSection\", \"Url\", \"Timestamp\"],\n",
    "                 decimal='.',\n",
    "                 quotechar='\\'',\n",
    "                 quoting=csv.QUOTE_NONNUMERIC)\n",
    "\n",
    "# add column for indices\n",
    "df['Index'] = df.index.values.astype(int)\n",
    "\n",
    "# add round annotation (indicates labeling time)\n",
    "df['Round'] = np.nan\n",
    "\n",
    "# initialize label column with -1 for unlabeled samples\n",
    "df['Label'] = np.full((len(df)), -1).astype(int)\n",
    "\n",
    "# add column for estimated probability\n",
    "df['Probability'] = np.nan\n",
    "\n",
    "# store auto-estimated label, initialize with -1 for unestimated samples\n",
    "df['Estimated'] = np.full((len(df)), -1).astype(int)\n",
    "\n",
    "# row number\n",
    "n_rows = df.shape[0]\n",
    "print('Number of samples in data set in total: {}'.format(n_rows))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the previously created dictionary that maps each article index (key) to the list of organizations mentioned in that article (value).\n",
    "In the following, we limit the number of occurrences of any company name across all labeled articles to 3 to avoid imbalance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_next(index):\n",
    "    ''' Display an article's headline and text together with an interactive slider\n",
    "        for setting its label manually.\n",
    "    '''\n",
    "    print('News article no. {}:'.format(index))\n",
    "    print()\n",
    "    print('HEADLINE:')\n",
    "    print(df.loc[df['Index'] == index, 'Title'])\n",
    "    print()\n",
    "    print('TEXT:')\n",
    "    print(df.loc[df['Index'] == index, 'Text'])\n",
    "\n",
    "    def f(x):\n",
    "        # save user input\n",
    "        df.loc[df['Index'] == index, 'Label'] = x\n",
    "        df.loc[df['Index'] == index, 'Round'] = m\n",
    "\n",
    "    # create slider widget for labels, preset to the estimated label (-1 if none)\n",
    "    estimated = int(df.loc[df['Index'] == index, 'Estimated'].iloc[0])\n",
    "    interact(f, x=widgets.IntSlider(min=-1, max=2, step=1, value=estimated))\n",
    "    print('0: Other/Unrelated news, 1: Merger,')\n",
    "    print('2: Topics related to deals, investments and mergers')\n",
    "    print('(e.g. merger pending/in talks/to be approved or merger rejected/aborted/denied or sale of unit or')\n",
    "    print('Share Deal/Asset Deal/acquisition or merger as incidental remark/not main topic/not current or speculative)')\n",
    "    print('___________________________________________________________________________________________________________')\n",
    "    print()\n",
    "    print()\n",
    "\n",
    "# list of article indices that will be shown next\n",
    "label_next = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global dict of all articles (article index => list of mentioned organizations)\n",
    "dict_art_orgs = {}\n",
    "with open('../obj/dict_articles_organizations_without_banks.pkl', 'rb') as f:\n",
    "    dict_art_orgs = pickle.load(f)\n",
    "\n",
    "# global dict of mentioned companies in labeled articles (company name => number of occurrences)\n",
    "dict_limit = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iteration part starts here:\n",
    "\n",
    "## Part II: Manual checking of estimated labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PLEASE SET THE ROUND NUMBER M MANUALLY IF THE PROCESS HAS BEEN INTERRUPTED BEFORE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Last round number: 9\n",
      "Number of manually labeled articles: 1000\n",
      "Number of manually unlabeled articles: 9000\n"
     ]
    }
   ],
   "source": [
    "# read current data set from csv\n",
    "df = pd.read_csv('../data/interactive_labeling_round_{}.csv'.format(m),\n",
    "                 sep='|',\n",
    "                 usecols=range(1,13), # drop first column 'unnamed'\n",
    "                 encoding='utf-8',\n",
    "                 quoting=csv.QUOTE_NONNUMERIC,\n",
    "                 quotechar='\\'')\n",
    "\n",
    "# find current iteration/round number\n",
    "m = int(df['Round'].max())\n",
    "print('Last round number: {}'.format(m))\n",
    "print('Number of manually labeled articles: {}'.format(len(df.loc[df['Label'] != -1])))\n",
    "print('Number of manually unlabeled articles: {}'.format(len(df.loc[df['Label'] == -1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize dict_limit\n",
    "df_labeled = df[df['Label'] != -1]\n",
    "\n",
    "for index in df_labeled['Index']:\n",
    "    orgs = dict_art_orgs[index]\n",
    "    for org in orgs:\n",
    "        if org in dict_limit:\n",
    "            dict_limit[org] += 1\n",
    "        else:\n",
    "            dict_limit[org] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# OPTIONAL:\n",
    "# print organizations that are mentioned 3 times and therefore limited\n",
    "for k, v in dict_limit.items():\n",
    "    if v == 3:\n",
    "        print(k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now check (and correct if necessary) the next 100 auto-labeled articles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if m == -1:\n",
    "    # first round: every article is a candidate\n",
    "    indices = list(range(len(df)))\n",
    "else:\n",
    "    # indices of recently auto-labeled articles\n",
    "    indices = df.loc[(df['Estimated'] != -1) & (df['Label'] == -1), 'Index'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# increment round number\n",
    "m += 1\n",
    "print('This round number: {}'.format(m))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pick_random_articles(n, limit=3):\n",
    "    ''' Pick up to n random articles in which every mentioned company\n",
    "        occurs fewer than `limit` times in the articles chosen so far.\n",
    "        Returns the list of indices of the articles we can label next\n",
    "        (shorter than n if the candidate pool runs out).\n",
    "    '''\n",
    "    # labeling list\n",
    "    list_arts = []\n",
    "    while len(list_arts) < n and indices:\n",
    "        # pick random article and take it out of the candidate pool;\n",
    "        # a rejected article can never become eligible again because\n",
    "        # the counts in dict_limit only grow\n",
    "        rand_i = random.choice(indices)\n",
    "        indices.remove(rand_i)\n",
    "        # list of companies in that article\n",
    "        companies = dict_art_orgs[rand_i]\n",
    "        if all(dict_limit.get(company, 0) < limit for company in companies):\n",
    "            for company in companies:\n",
    "                dict_limit[company] = dict_limit.get(company, 0) + 1\n",
    "            # add article to labeling list\n",
    "            list_arts.append(rand_i)\n",
    "    return list_arts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate new list of article indices for labeling\n",
    "batchsize = 100\n",
    "label_next = pick_random_articles(batchsize)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PLEASE READ THE FOLLOWING ARTICLES AND ENTER THE CORRESPONDING LABELS:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for index in label_next:\n",
    "    show_next(index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Number of manual labels in round no. {}:'.format(m))\n",
    "print('0:{}, 1:{}, 2:{}'.format(len(df.loc[(df['Label'] == 0) & (df['Round'] == m)]), len(df.loc[(df['Label'] == 1) & (df['Round'] == m)]), len(df.loc[(df['Label'] == 2) & (df['Round'] == m)])))\n",
    "\n",
    "print('Number of articles to be corrected in this round: {}'.format(len(df.loc[(df['Label'] != -1) & (df['Estimated'] != -1) & (df['Round'] == m) & (df['Label'] != df['Estimated'])])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save intermediate status\n",
    "df.to_csv('../data/interactive_labeling_round_{}_temp.csv'.format(m),\n",
    "          sep='|',\n",
    "          mode='w',\n",
    "          encoding='utf-8',\n",
    "          quoting=csv.QUOTE_NONNUMERIC,\n",
    "          quotechar='\\'')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#df.loc[df['Label'] != -1][:100]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part III: Model building and automated labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# THIS CELL IS OPTIONAL\n",
    "\n",
    "# read current data set from csv\n",
    "# PLEASE ENTER THE ROUND NUMBER OF THE TEMPORARY FILE TO LOAD\n",
    "m = int(input('Round number: '))\n",
    "df = pd.read_csv('../data/interactive_labeling_round_{}_temp.csv'.format(m),\n",
    "                 sep='|',\n",
    "                 usecols=range(1,13), # drop first column 'unnamed'\n",
    "                 encoding='utf-8',\n",
    "                 quoting=csv.QUOTE_NONNUMERIC,\n",
    "                 quotechar='\\'')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We build a classification model and check if it is possible to label articles automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# MNB: starting label propagation\n",
      "# BOW: extracting all words from articles...\n",
      "\n",
      "# BOW: making vocabulary of data set...\n",
      "\n",
      "# BOW: vocabulary consists of 14414 features.\n",
      "\n",
      "# MNB: fit training data and calculate matrix...\n",
      "\n",
      "# BOW: calculating matrix...\n",
      "\n",
      "# BOW: calculating frequencies...\n",
      "\n",
      "# MNB: transform testing data to matrix...\n",
      "\n",
      "# BOW: extracting all words from articles...\n",
      "\n",
      "# BOW: calculating matrix...\n",
      "\n",
      "# BOW: calculating frequencies...\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Anne\\Anaconda3\\lib\\site-packages\\sklearn\\semi_supervised\\label_propagation.py:205: RuntimeWarning: invalid value encountered in true_divide\n",
      "  probabilities /= normalizer\n",
      "C:\\Users\\Anne\\Anaconda3\\lib\\site-packages\\sklearn\\semi_supervised\\label_propagation.py:205: RuntimeWarning: invalid value encountered in true_divide\n",
      "  probabilities /= normalizer\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# MNB: ending label propagation\n",
      "Wall time: 41min 56s\n"
     ]
    }
   ],
   "source": [
    "# whether to use sklearn's CountVectorizer (here: no)\n",
    "cv = False\n",
    "\n",
    "# call script with manually labeled and manually unlabeled samples\n",
    "%time class_probs, predictions = LabelPropagation.propagate_labels(df.loc[df['Label'] != -1], df.loc[df['Label'] == -1], cv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We label each article with class $j$ if its estimated probability for class $j$ is higher than our threshold:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [ 1. 0. 0.]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]\n",
      " [nan nan nan]]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "print(class_probs[:100])\n",
    "print(predictions[:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only labels with more than this probability are adopted\n",
    "threshold = 0.99\n",
    "# dict for counting estimated labels\n",
    "estimated_labels = {0:0, 1:0, 2:0}\n",
    "\n",
    "# class label of each column of class_probs\n",
    "# (assumption: columns follow the sorted class labels, like sklearn's classes_)\n",
    "classes = np.unique(df.loc[df['Label'] != -1, 'Label'])\n",
    "\n",
    "# list of indices of recently estimated articles\n",
    "indices_estimated = df.loc[df['Label'] == -1, 'Index'].tolist()\n",
    "\n",
    "# for every row i and every element j in row i\n",
    "for (i, j), value in np.ndenumerate(class_probs):\n",
    "    # adopt class j if its probability exceeds the threshold (nan never does)\n",
    "    if value > threshold:\n",
    "        index = indices_estimated[i]\n",
    "        # save estimated label\n",
    "        df.loc[index, 'Estimated'] = classes[j]\n",
    "        # annotate probability\n",
    "        df.loc[index, 'Probability'] = value\n",
    "        # count labels\n",
    "        estimated_labels[int(classes[j])] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Number of auto-labeled samples in round {}: {}'.format(m, sum(estimated_labels.values())))\n",
    "print('Estimated labels: {}'.format(estimated_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# THIS CELL IS OPTIONAL\n",
    "# let the Naive Bayes algorithm test the quality of the data set's labels\n",
    "\n",
    "# split data into text and label set\n",
    "X = df.loc[df['Label'] != -1, 'Title'] + '. ' + df.loc[df['Label'] != -1, 'Text']\n",
    "X = X.reset_index(drop=True)\n",
    "y = df.loc[df['Label'] != -1, 'Label']\n",
    "y = y.reset_index(drop=True)\n",
    "\n",
    "# whether to use sklearn's CountVectorizer (here: no)\n",
    "cv = False\n",
    "\n",
    "# call script with the texts and labels of the manually labeled samples\n",
    "#%time MNBInteractive.measure_mnb(X, y, cv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('End of this round (no. {}):'.format(m))\n",
    "print('Number of manually labeled articles: {}'.format(len(df.loc[df['Label'] != -1])))\n",
    "print('Number of manually unlabeled articles: {}'.format(len(df.loc[df['Label'] == -1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save this round to csv\n",
    "df.to_csv('../data/interactive_labeling_round_{}.csv'.format(m),\n",
    "          sep='|',\n",
    "          mode='w',\n",
    "          encoding='utf-8',\n",
    "          quoting=csv.QUOTE_NONNUMERIC,\n",
    "          quotechar='\\'')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NOW PLEASE CONTINUE WITH PART II.\n",
    "REPEAT UNTIL ALL SAMPLES ARE LABELED."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}