Let's now apply the same model to multiclass classification, using the 20 Newsgroups dataset. This dataset is small for training a CNN, so we will simplify the problem. As discussed before, the 20 classes in this dataset overlap considerably, and with an SVM we get a maximum of 70% accuracy. Here, we will instead take the six broad categories of this dataset and build a CNN classifier for them, which means first mapping the 20 categories to the six broad ones. The following code loads the dataset from scikit-learn:
from sklearn.datasets import fetch_20newsgroups

def load_20newsgroup_data(categories=None, subset='all'):
    data = fetch_20newsgroups(subset=subset, shuffle=True, remove=('headers', 'footers', ...
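A minimal sketch of the 20-to-6 mapping step follows. It assumes the six-way topic partition listed on the 20 Newsgroups homepage (computers, recreation, science, for-sale, politics, religion); the names `BROAD_GROUPS`, `NEWSGROUP_TO_BROAD`, `BROAD_NAMES`, and `to_broad_labels` are hypothetical, and the grouping used in the original code may differ.

```python
# Hypothetical grouping of the 20 newsgroup names into six broad topics,
# assumed from the partition shown on the 20 Newsgroups homepage.
BROAD_GROUPS = {
    "computer": ["comp.graphics", "comp.os.ms-windows.misc",
                 "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware",
                 "comp.windows.x"],
    "recreation": ["rec.autos", "rec.motorcycles",
                   "rec.sport.baseball", "rec.sport.hockey"],
    "science": ["sci.crypt", "sci.electronics", "sci.med", "sci.space"],
    "forsale": ["misc.forsale"],
    "politics": ["talk.politics.misc", "talk.politics.guns",
                 "talk.politics.mideast"],
    "religion": ["talk.religion.misc", "alt.atheism",
                 "soc.religion.christian"],
}

# Invert the grouping: fine-grained newsgroup name -> broad group name.
NEWSGROUP_TO_BROAD = {
    newsgroup: broad
    for broad, newsgroups in BROAD_GROUPS.items()
    for newsgroup in newsgroups
}

# Sorted so that each broad group always gets the same integer id.
BROAD_NAMES = sorted(BROAD_GROUPS)


def to_broad_labels(target, target_names):
    """Map integer targets over the 20 newsgroups to integer ids over the
    six broad groups. `target_names[t]` must be a newsgroup name."""
    broad_index = {name: i for i, name in enumerate(BROAD_NAMES)}
    return [broad_index[NEWSGROUP_TO_BROAD[target_names[t]]] for t in target]
```

Given a `data` object returned by `fetch_20newsgroups`, `to_broad_labels(data.target, data.target_names)` would produce six-way labels for the CNN. Sorting the group names keeps the integer ids stable between runs.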