1 package edu.uci.ics.jung.algorithms.shortestpath;
2
3 import java.util.Collection;
4 import java.util.HashSet;
5 import java.util.Set;
6
7 import org.apache.commons.collections15.Factory;
8 import org.apache.commons.collections15.Transformer;
9 import org.apache.commons.collections15.functors.ConstantTransformer;
10
11 import edu.uci.ics.jung.graph.Graph;
12 import edu.uci.ics.jung.graph.util.Pair;
13
14
15
16
17
18
19
20
21
22
23 @SuppressWarnings("unchecked")
24 public class PrimMinimumSpanningTree<V,E> implements Transformer<Graph<V,E>,Graph<V,E>> {
25
26 protected Factory<? extends Graph<V,E>> treeFactory;
27 protected Transformer<E,Double> weights;
28
29
30
31
32 public PrimMinimumSpanningTree(Factory<? extends Graph<V,E>> factory) {
33 this(factory, new ConstantTransformer(1.0));
34 }
35
36
37
38
39 public PrimMinimumSpanningTree(Factory<? extends Graph<V,E>> factory,
40 Transformer<E, Double> weights) {
41 this.treeFactory = factory;
42 if(weights != null) {
43 this.weights = weights;
44 }
45 }
46
47
48
49
50 public Graph<V,E> transform(Graph<V,E> graph) {
51 Set<E> unfinishedEdges = new HashSet<E>(graph.getEdges());
52 Graph<V,E> tree = treeFactory.create();
53 V root = findRoot(graph);
54 if(graph.getVertices().contains(root)) {
55 tree.addVertex(root);
56 } else if(graph.getVertexCount() > 0) {
57
58 tree.addVertex(graph.getVertices().iterator().next());
59 }
60 updateTree(tree, graph, unfinishedEdges);
61
62 return tree;
63 }
64
65 protected V findRoot(Graph<V,E> graph) {
66 for(V v : graph.getVertices()) {
67 if(graph.getInEdges(v).size() == 0) {
68 return v;
69 }
70 }
71
72 if(graph.getVertexCount() > 0) {
73 return graph.getVertices().iterator().next();
74 }
75
76 return null;
77 }
78
79 protected void updateTree(Graph<V,E> tree, Graph<V,E> graph, Collection<E> unfinishedEdges) {
80 Collection<V> tv = tree.getVertices();
81 double minCost = Double.MAX_VALUE;
82 E nextEdge = null;
83 V nextVertex = null;
84 V currentVertex = null;
85 for(E e : unfinishedEdges) {
86
87 if(tree.getEdges().contains(e)) continue;
88
89
90 Pair<V> endpoints = graph.getEndpoints(e);
91 V first = endpoints.getFirst();
92 V second = endpoints.getSecond();
93 if((tv.contains(first) == true && tv.contains(second) == false)) {
94 if(weights.transform(e) < minCost) {
95 minCost = weights.transform(e);
96 nextEdge = e;
97 currentVertex = first;
98 nextVertex = second;
99 }
100 } else if((tv.contains(second) == true && tv.contains(first) == false)) {
101 if(weights.transform(e) < minCost) {
102 minCost = weights.transform(e);
103 nextEdge = e;
104 currentVertex = second;
105 nextVertex = first;
106 }
107 }
108 }
109
110 if(nextVertex != null && nextEdge != null) {
111 unfinishedEdges.remove(nextEdge);
112 tree.addEdge(nextEdge, currentVertex, nextVertex);
113 updateTree(tree, graph, unfinishedEdges);
114 }
115 }
116 }