Friday, March 26, 2010

ID3-impl

Below is my implementation of the ID3 algorithm based on my story about it.

It builds decision tree for next training data:

 AGE | COMPETITION | TYPE | PROFIT
 =========================================
 old | yes       | swr | down (False in my impl)
 --------+-------------+---------+--------
 old | no       | swr  | down
 --------+-------------+---------+--------
 old | no       | hwr | down
 --------+-------------+---------+--------
 mid | yes       | swr | down
 --------+-------------+---------+--------
 mid | yes       | hwr | down
 --------+-------------+---------+--------
 mid | no       | hwr | up (True in my impl)
 --------+-------------+---------+--------
 mid | no       | swr | up
 --------+-------------+---------+--------
 new | yes       | swr | up
 --------+-------------+---------+--------
 new | no       | hwr | up
 --------+-------------+---------+--------
 new | no       | swr | up
 --------+-------------+---------+--------

And built tree looks like this:

           Age
         / |    \
        /  |     \
    new/   |mid   \old
      /    |       \
    True Competition False
         /      \
        /        \
     no/          \yes
      /            \
    True             False



The Implementation of algorithm ID3


using System;
using System.Collections.Generic;
using System.Linq;

namespace ID3
{
    public static class Program
    {

        static void Main(string[] args)
        {
            var R = new Dictionary<string, List<string>>();
            R.Add("Age", new List<string>() { "old", "mid", "new" });
            R.Add("Competition", new List<string>() { "yes", "no" });
            R.Add("Type", new List<string>() { "hwr", "swr" });


            var C = "Profit";
            var TrainingSet = GetTrainingData();
            var algorithm = new Id3Algorithm();
            Tree desicionTree = algorithm.ID3(R, C, "root", TrainingSet);
        }

        private static List<TrainingRecord> GetTrainingData()
        {
            var trainingRecords = new List<TrainingRecord>();
            Dictionary<string, string> attributes;

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "old");
            attributes.Add("Competition", "yes");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, false));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "old");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, false));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "old");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "hwr");
            trainingRecords.Add(new TrainingRecord(attributes, false));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "mid");
            attributes.Add("Competition", "yes");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, false));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "mid");
            attributes.Add("Competition", "yes");
            attributes.Add("Type", "hwr");
            trainingRecords.Add(new TrainingRecord(attributes, false));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "mid");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "hwr");
            trainingRecords.Add(new TrainingRecord(attributes, true));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "mid");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, true));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "new");
            attributes.Add("Competition", "yes");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, true));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "new");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "hwr");
            trainingRecords.Add(new TrainingRecord(attributes, true));

            attributes = new Dictionary<string, string>();
            attributes.Add("Age", "new");
            attributes.Add("Competition", "no");
            attributes.Add("Type", "swr");
            trainingRecords.Add(new TrainingRecord(attributes, true));


            return trainingRecords;
        }
    }

    internal class Tree
    {
        public string Name { get; private set; }
        public string ArcName { get; private set; }
        public bool IsLeaf{ get; private set; }

        public Dictionary<string, Tree> Trees { get; private set; }

        public Tree(string name, string arcName, Dictionary<string, Tree> trees)
        {
            Name = name;
            ArcName = arcName;
            Trees = trees;
            if (Trees == null) IsLeaf = true;
            else if (Trees.Count <= 0) IsLeaf = true;
        }
    }

    internal class TrainingRecord
    {
        public Dictionary<string, string> Attributes { get; private set; }
        public bool Success { get; private set; }

        public TrainingRecord(Dictionary<string, string> attributes, bool success)
        {
            Attributes = attributes;
            Success = success;
        }
    }


    /*    function ID3 (R: множина некатегоризаційних властивостей,
       C: категоризаційна властивість,
       S: множина для навчання) returns дерево прийняття рішень;
       begin
     Якщо S пуста, повернути один вузол із значенням невдача;
     Якщо S складаєтсья із рядків, для яких значення категоризаційної
        властивості одне й те ж,
        повернути єдиний вузол із тим значенням;
     Якщо R пуста, тоді повернути єдиний вузол із значенням, яке є
        найбільш частішим серед значень катеригоційної властивості,
        що було знайдено серед рядків S;
     Нехай D є властивістю із найбільшим приростом Gain(D,S)
        серед властивостей в множині R;
     Нехай {dj| j=1,2, .., m} - значення властивості D;
     Нехай {Sj| j=1,2, .., m} - підмножини S, що включають
        відповідні рядки із значенням dj для властивості D;
     Повернути дерево із коренем поміченим D і дуги позначені
        d1, d2, .., dm що продовжуються наступними деревами

          ID3(R-{D}, C, S1), ID3(R-{D}, C, S2), .., ID3(R-{D}, C, Sm);
       end ID3;
 */

    internal class Id3Algorithm
    {
        public Tree ID3(Dictionary<string, List<string>> R, string C, string arcName, List<TrainingRecord> S)
        {
            //1
            if(S.Count <= 0) return new Tree(false.ToString(), arcName, null);
            //2
            var prevValue = S[0].Success;
            foreach (var trainingRecord in S)
            {
                if(prevValue != trainingRecord.Success)
                {
                    prevValue = trainingRecord.Success;
                    break;
                }
            }
            if(prevValue == S[0].Success)
            {
                return new Tree(prevValue.ToString(), arcName, null);
            }
            //3
            if(R.Count <= 0)
            {
                var sCount = S.Where(x => x.Success).Count();
                var fCount = S.Where(x => !x.Success).Count();
                var resValue = (sCount < fCount) ? true : false;

                new Tree(resValue.ToString(), arcName, null);
            }
            //4
            double maxGain = double.MinValue;
            string maxAtrb = string.Empty;
            foreach (var attribute in R)
            {
                double currGain = Gain(attribute.Key, attribute.Value, S);
                if(currGain > maxGain)
                {
                    maxGain = currGain;
                    maxAtrb = attribute.Key;
                }
            }

            var partitioning = new Dictionary<string, List<TrainingRecord>>();

            foreach (var posValue in R[maxAtrb])
            {
                var Si = S.Where(x => x.Attributes[maxAtrb] == posValue).ToList();
                partitioning.Add(posValue, Si);
            }
            R.Remove(maxAtrb);

            var childTrees = new Dictionary<string, Tree>();
            foreach (var Si in partitioning)
            {
                childTrees.Add(Si.Key, ID3(R, C, Si.Key, Si.Value));
            }

            return new Tree(maxAtrb, arcName, childTrees);
        }

        private double Gain(string s, List<string> posValues, List<TrainingRecord> trainingRecords)
        {
            return Info(trainingRecords) - Info(s, posValues, trainingRecords);
        }

        private double Info(string attribute, List<string> posValues, List<TrainingRecord> list)
        {
            double nGeneral = list.Count;
            double sum = 0;
            foreach (var value in posValues)
            {
                var sCount = list.Where(x => (x.Attributes[attribute] == value) && x.Success).Count();
                var fCount = list.Where(x => (x.Attributes[attribute] == value) && (!x.Success)).Count();
                var n = (double)(sCount + fCount);
                var iValue = I(new List<double>() { sCount / n, fCount / n });
                sum += (n / nGeneral) * iValue;
            }
            return sum;
        }

        private double Info(List<TrainingRecord> trainingRecords)
        {
            int n = trainingRecords.Count;
            var sCount = trainingRecords.Where(x => x.Success == true).Count();
            var fCount = trainingRecords.Where(x => x.Success == false).Count();
            var p1 = sCount / (double)n;
            var p2 = fCount / (double)n;

            double info = I(new List<double>() { p1, p2 });
            return info;
        }

        private double I(List<double> P)
        {
            double sum = 0;
            foreach (var p in P)
            {
                if (p != 0)
                {
                    sum += p * Math.Log(p, 2);
                }
            }
            return -sum;
        }
    }
}

and the result in Competition node from debug mode:


That is bold path in tree below:


           Age 
         / |    \
        /  |     \
    new/   |mid   \old
      /    |       \
    True Competition False
         /      \
        /        \
     no/          \yes
      /            \
    True             False

I'm going to implement all  this stuff online tomorrow for students who will listen to me.

0 comments:

Post a Comment