这里主要介绍三个部分

  1. id3算法的实现,包括决策树的实现、算法的实现,输出为自己定义的决策树模型,这里是用了arff格式的文件
  2. 决策树的持久化,在大型数据集中,决策树模型过大,训练完需要持久化到磁盘上,供下次使用,这里采用了xml进行持久化
  3. 决策树的使用,重新使用xml文件中的决策树模型来对测试数据进行测试

机器学习相关代码欢迎关注https://github.com/xixy/MachineLearningAlgorithm

1 训练

本次训练按照arff格式文件进行编码,训练集如下所示

@relation weather.symbolic
 
@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}
 
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

决策树中的节点模型

/**
 * 用来描述决策树中的一个节点
 */
public class Node {

    private String attribute;// 属性名称
    private String value;// 属性值
    private List<Node> childs;// 子节点
    private String label;// 类别标记,如果是叶子结点的话

    public Node(String attribute, String value) {
        this.setAttribute(attribute);
        this.setValue(value);
    }

    public void addChild(Node child) {
        if (childs == null)
            childs = new ArrayList<Node>();
        childs.add(child);
    }

    public String getAttribute() {
        return attribute;
    }

    public void setAttribute(String attribute) {
        this.attribute = attribute;
    }

    public String getValue() {
        return value;
    }

    public void setValue(String value) {
        this.value = value;
    }

    public List<Node> getChilds() {
        return childs;
    }

    public void setChilds(List<Node> childs) {
        this.childs = childs;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }
}

决策树模型

/**
 * 决策树的生成
 */
public class DecisionTree {
    private Node root;

    public DecisionTree() {
        root = new Node("DecisionTree", "NULL");
    }

    public Node getRoot() {
        return root;
    }

    public void setRoot(Node root) {
        this.root = root;
    }

    /**
     * 迭代产生决策树
     * 
     * @param node
     *            当前节点,已经给出了属性和属性值
     * @param data
     *            当前属性值下的数据集
     * @param selatt
     *            当前可用的属性index
     * @param attributevalue
     *            属性值list
     * @param attribute
     *            属性list
     */
    public static void buildDecisionTreeRecursively(Node node, ArrayList<String[]> data, LinkedList<Integer> selatt,
            ArrayList<ArrayList<String>> attributevalue, ArrayList<String> attribute) {
        // 如果数据已经是同一类别了,那就没有意义往下分了
        if (GainCalculator.isPureDataSet(data)) {
            node.setLabel(data.get(0)[data.get(0).length - 1]);
            return;
        }
        // 如果数据不够纯净
        // 选择熵最小的,也就是得到信息增益最大的,这里并没有计算g(D,A),而是计算了H(D|A),选择最小的H(D|A),即可得到最大的g(D,A)
        int minIndex = -1;// 属性index
        double minEntropy = Double.MAX_VALUE;// 最小熵
        for (int i = 0; i < selatt.size(); i++) {
            double entropy = GainCalculator.calNodeEntropy(data, selatt.get(i), attributevalue);
            if (entropy < minEntropy) {
                minIndex = selatt.get(i);
                minEntropy = entropy;
            }
        }

        // 获取属性名称
        String nodeName = attribute.get(minIndex);
        // 去掉已选属性
        selatt.remove(new Integer(minIndex));
        // 得到该属性的所有值
        ArrayList<String> attvalues = attributevalue.get(minIndex);
        // 按照不同属性值进行划分数据集
        for (String val : attvalues) {
            Node child = new Node(nodeName, val);
            node.addChild(child);
            ArrayList<String[]> subset = new ArrayList<String[]>();
            for (int i = 0; i < data.size(); i++) {
                if (data.get(i)[minIndex].equals(val)) {
                    subset.add(data.get(i));
                }
            }
            buildDecisionTreeRecursively(child, subset, selatt, attributevalue, attribute);
        }

    }

    /**
     * 构建决策树
     * 
     * @param data
     *            训练数据集
     * @param selatt
     *            属性index列表
     * @param attributevalue
     *            属性值列表
     * @param attribute
     *            属性列表
     * @return
     */
    public Node buildDecisionTree(ArrayList<String[]> data, LinkedList<Integer> selatt,
            ArrayList<ArrayList<String>> attributevalue, ArrayList<String> attribute) {
        buildDecisionTreeRecursively(root, data, selatt, attributevalue, attribute);
        return root;

    }

}

熵计算的类,这里的熵计算有一些小技巧,需要推导,

g(D,A)=H(D)-H(D|A),因为对于所有的A来说,H(D)相同,为了得到最大的信息增益g(D,A),只需要计算得到最小的H(D|A)即可

/**
 * 用于计算熵
 */
public class GainCalculator {

    /**
     * 判断数据集是否为同一类别的数据集
     * 
     * @param data
     *            数据集
     * @return
     */
    public static boolean isPureDataSet(ArrayList<String[]> data) {
        String[] data1 = data.get(0);
        int labelIndex = data1.length - 1;
        String label1 = data1[labelIndex];
        for (int i = 1; i < data.size(); i++) {
            if (!label1.equals(data.get(i)[labelIndex]))
                return false;
        }
        return true;
    }

    /**
     * 计算样本的熵,其中数组是同一属性值下的不同类别的数量
     * 
     * @param arr
     * @return
     */
    public static double getEntropy(int[] arr) {
        double entropy = 0.0;
        int sum = 0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2);
            sum += arr[i];
        }
        entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * 根据给定的原始数据中的子集,以当前第index个属性节点计算他的信息熵
     * 
     * @param subSet
     * @param attributeIndex
     * @param attributevalue
     * @return
     */
    public static double calNodeEntropy(ArrayList<String[]> subSet, int attributeIndex,
            ArrayList<ArrayList<String>> attributevalue) {
        int sum = subSet.size();
        int decatt = subSet.get(0).length - 1;// 预测label的index
        double entropy = 0.0;
        int[][] info = new int[attributevalue.get(attributeIndex).size()][];
        for (int i = 0; i < info.length; i++)
            info[i] = new int[attributevalue.get(decatt).size()];
        int[] count = new int[attributevalue.get(attributeIndex).size()];// 不同属性值的计数
        // 统计类别和属性关系
        for (int i = 0; i < sum; i++) {
            String nodevalue = subSet.get(i)[attributeIndex];// 属性值,例如outlook是sunny
            int nodeind = attributevalue.get(attributeIndex).indexOf(nodevalue);// 属性值对应的index,例如sunny对应{sunny,rainy}中的0
            count[nodeind]++;// 属性个数+1
            String decvalue = subSet.get(i)[decatt];// 第i个数的类别
            int decind = attributevalue.get(decatt).indexOf(decvalue);// 类别对应的index,例如yes对应{yes,no}中的0
            info[nodeind][decind]++;// 增加相应属性值下的相应类别+1
        }
        // 计算
        // |Di|/|D|*H(Di)
        for (int i = 0; i < info.length; i++) {
            entropy += getEntropy(info[i]) * count[i] / sum;
        }
        return entropy;
    }

}

读取arff文件的类,生成相应的训练数据、属性、属性值集合等


/**
 * 用来读取arff文件,生成相应的训练数据、属性、属性值等
 */
public class DataReader {
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";

    /**
     * 读取arff文件,并给attribute、attributevalue、data赋值
     * 
     * @param file
     *            文件
     * @param attribute
     *            属性名称
     * @param attributevalue
     *            属性值
     * @param data
     *            得到的数据列表
     */
    public static void readARFF(File file, ArrayList<String> attribute, ArrayList<ArrayList<String>> attributevalue,
            ArrayList<String[]> data) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) { // 读@attribute
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    attributevalue.add(al);
                } else if (line.startsWith("@data")) { // 读@data
                    while ((line = br.readLine()) != null) {
                        if (line == "")
                            continue;
                        String[] row = line.split(",");
                        data.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

}

ID3算法的类,主要用于调用其他的类来完成模型的生成

public class ID3 {
    public ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称
    public ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值
    public ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据
    public int decatt; // 决策变量在属性集中的索引

    /**
     * 设置决策变量index
     * 
     * @param n
     *            index
     */
    public void setDec(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.err.println("决策变量指定错误。");
            System.exit(2);
        }
        decatt = n;
    }

    /**
     * 表示决策变量
     * 
     * @param name
     *            据测变量名称
     */
    public void setDec(String name) {
        int n = attribute.indexOf(name);
        setDec(n);
    }

    /**
     * @param args
     */
    public static void main(String[] args) {
        ID3 inst = new ID3();
        DataReader.readARFF(new File(Setting.trainingfile), inst.attribute, inst.attributevalue, inst.data);

        inst.setDec("play");
        LinkedList<Integer> ll = new LinkedList<Integer>();// 非决策属性的index列表
        for (int i = 0; i < inst.attribute.size(); i++) {
            if (i != inst.decatt)
                ll.add(i);
        }
        ArrayList<Integer> al = new ArrayList<Integer>();// 所有的data的列表
        for (int i = 0; i < inst.data.size(); i++) {
            al.add(i);
        }
        DecisionTree dt = new DecisionTree();
        dt.buildDecisionTree(inst.data, ll, inst.attributevalue, inst.attribute);
        XmlGenerator xmlGenerator = new XmlGenerator();
        xmlGenerator.generateXml(dt);
        xmlGenerator.outputXmlFile(Setting.xmlfile);
        return;
    }

}

2 模型持久化

得到的DecisionTree实例持久化到xml文件中

/**
 * 根据决策树结构生成xml文件用于持久化
 */
public class XmlGenerator {

    public Document xmldoc = null;

    /**
     * 根据决策树生成xml
     * 
     * @param dt
     *            决策树模型
     * @return xmldocument
     */
    public Document generateXml(DecisionTree dt) {
        xmldoc = DocumentHelper.createDocument();
        Element root = xmldoc.addElement("root");
        Node treeroot = dt.getRoot();

        generateXmlNodeRecursively(root, treeroot);
        return xmldoc;
    }

    /**
     * 迭代生成xml node
     * 
     * @param e
     *            xml node
     * @param nd
     *            决策树中的node
     */
    public void generateXmlNodeRecursively(Element e, Node nd) {
        Element xmlChild = e.addElement(nd.getAttribute()).addAttribute("value", nd.getValue());
        if (nd.getLabel() != null) {
            xmlChild.setText(nd.getLabel());
            return;
        }
        for (Node child : nd.getChilds())
            generateXmlNodeRecursively(xmlChild, child);
    }

    /**
     * 输出xml文件
     * 
     * @param filename
     */
    public void outputXmlFile(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileWriter fw = new FileWriter(file);
            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
            XMLWriter output = new XMLWriter(fw, format);
            output.write(xmldoc);
            output.close();
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }

}

得到输出结果如下所示

<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="NULL">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>

3 使用持久化的模型进行预测

采用xml文件中存储的决策树模型来进行预测,这里直接采用xmldoc来进行操作,而没有将其转化为DecisionTree实例进行匹配,因此更具有普适性

/**
 * 采用得到的决策树进行预测
 */
public class Prediction {

	private ArrayList<String[]> testData = null;
	Document xmldoc;
	private Map<String, Integer> attributeIndex = null;

	/**
	 * 加载测试文件
	 */
	public void loadTestFile() {
		testData = new ArrayList<String[]>();

		try {
			FileReader fr = new FileReader(new File(Setting.testfile));
			@SuppressWarnings("resource")
			BufferedReader br = new BufferedReader(fr);
			String line;
			while ((line = br.readLine()) != null) {
				String[] values = line.split(",");
				testData.add(values);
			}
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * 读取决策树的xml文件
	 */
	public void loadDecisionTree() {
		SAXReader reader = new SAXReader();
		try {
			xmldoc = reader.read(new File(Setting.xmlfile));
		} catch (DocumentException e) {
			e.printStackTrace();
		}
	}

	/**
	 * 通过递归来实现决策树的遍历和匹配
	 * 
	 * @param e
	 *            节点
	 * @param data
	 *            数据
	 * @return label
	 */
	public String validation(Element e, String[] data) {
		if (e.isRootElement() || e.getName().equals("DecisionTree")) {
			@SuppressWarnings("unchecked")
			Iterator<Element> it = e.elementIterator();
			while (it.hasNext()) {
				Element child = it.next();
				String result = validation(child, data);
				if (result != null)
					return result;
			}
		}

		Attribute attribute = e.attribute(0);
		String name = e.getName();
		if (!attributeIndex.containsKey(name))
			return null;
		int index = attributeIndex.get(name);
		String value = attribute.getValue();
		// 如果匹配当前节点条件,那么就以该节点为要求往下进行匹配
		if (value.equals(data[index])) {
			@SuppressWarnings("unchecked")
			Iterator<Element> it = e.elementIterator();
			// 如果是叶子结点,那么就直接给出类别
			if (it.hasNext() == false) {
				return e.getText();
			}
			while (it.hasNext()) {
				Element child = it.next();
				String result = validation(child, data);
				if (result != null)
					return result;
			}
		}
		return null;

	}

	/**
	 * 进行预测工作
	 */
	public void prediction() {
		// 首先获取到属性列表
		ID3 inst = new ID3();
		DataReader.readARFF(new File(Setting.trainingfile), inst.attribute, inst.attributevalue, inst.data);
		inst.setDec(inst.attribute.get(inst.attribute.size() - 1));
		LinkedList<Integer> ll = new LinkedList<Integer>();// 非决策属性的index列表

		attributeIndex = new HashMap<String, Integer>();// 属性与index的对应
		for (int i = 0; i < inst.attribute.size(); i++) {
			if (i != inst.decatt) {
				attributeIndex.put(inst.attribute.get(i), i);
				ll.add(i);
			}
		}
		Element root = xmldoc.getRootElement();
		// 对每一组数据,都进行规则上的遍历,采用宽度优先
		for (String[] data : testData) {
			String label = validation(root, data);
			System.out.println(label);
		}

	}

}

public class PredictionTest {

	/**
	 * @param args
	 */
	public static void main(String[] args) {

		Prediction p = new Prediction();
		p.loadDecisionTree();
		p.loadTestFile();
		p.prediction();

	}

}

即可得到测试结果的输出

登录发表评论 注册

反馈意见