Related: Design Sparse Matrix
http://www.1point3acres.com/bbs/thread-137068-1-1.html
设计一个 SparseVector 类，包含 set(long idx, int val)、get(long idx)、dotProduct(SparseVector otherVec) 三个方法。
(Design a SparseVector class with three methods: set(long idx, int val), get(long idx), and dotProduct(SparseVector otherVec).)
public class SparseVector { private int d; // dimension private ST<Integer, Double> st; // the vector, represented by index-value pairs
public SparseVector(int d) { this.d = d; this.st = new ST<Integer, Double>(); } public void put(int i, double value) { if (i < 0 || i >= d) throw new IndexOutOfBoundsException("Illegal index"); if (value == 0.0) st.delete(i); else st.put(i, value); } public double get(int i) { if (i < 0 || i >= d) throw new IndexOutOfBoundsException("Illegal index"); if (st.contains(i)) return st.get(i); else return 0.0; }
public double dot(SparseVector that) { if (this.d != that.d) throw new IllegalArgumentException("Vector lengths disagree"); double sum = 0.0; // iterate over the vector with the fewest nonzeros if (this.st.size() <= that.st.size()) { for (int i : this.st.keys()) if (that.st.contains(i)) sum += this.get(i) * that.get(i); } else { for (int i : that.st.keys()) if (this.st.contains(i)) sum += this.get(i) * that.get(i); } return sum; } public double dot(double[] that) { double sum = 0.0; for (int i : st.keys()) sum += that[i] * this.get(i); return sum; }
public SparseVector plus(SparseVector that) { if (this.d != that.d) throw new IllegalArgumentException("Vector lengths disagree"); SparseVector c = new SparseVector(d); for (int i : this.st.keys()) c.put(i, this.get(i)); // c = this for (int i : that.st.keys()) c.put(i, that.get(i) + c.get(i)); // c = c + that return c; }
}
/**
 * An ordered symbol table of key-value pairs backed by a {@link java.util.TreeMap}.
 *
 * <p>Keys must be non-null and {@link Comparable}. Null values are not stored: putting a null
 * value removes the key instead, so {@link #get} returns {@code null} exactly when the key is
 * absent.
 *
 * @param <Key> the key type
 * @param <Value> the value type
 */
public class ST<Key extends Comparable<Key>, Value> implements Iterable<Key> {
    private final java.util.TreeMap<Key, Value> st;

    /** Creates an empty symbol table. */
    public ST() {
        st = new java.util.TreeMap<Key, Value>();
    }

    /**
     * Returns the value associated with {@code key}.
     *
     * @param key the key
     * @return the value, or {@code null} if the key is absent
     * @throws NullPointerException if {@code key} is null
     */
    public Value get(Key key) {
        if (key == null) {
            throw new NullPointerException("called get() with null key");
        }
        return st.get(key);
    }

    /**
     * Associates {@code val} with {@code key}, replacing any old value. A null value removes the key.
     *
     * @param key the key
     * @param val the value, or {@code null} to delete the key
     * @throws NullPointerException if {@code key} is null
     */
    public void put(Key key, Value val) {
        if (key == null) {
            throw new NullPointerException("called put() with null key");
        }
        if (val == null) {
            st.remove(key);
        } else {
            st.put(key, val);
        }
    }

    /**
     * Removes {@code key} and its value, if present.
     *
     * @param key the key
     * @throws NullPointerException if {@code key} is null
     */
    public void delete(Key key) {
        if (key == null) {
            throw new NullPointerException("called delete() with null key");
        }
        st.remove(key);
    }

    /**
     * Returns whether the table contains {@code key}.
     *
     * @param key the key
     * @return {@code true} if the key is present
     * @throws NullPointerException if {@code key} is null
     */
    public boolean contains(Key key) {
        if (key == null) {
            throw new NullPointerException("called contains() with null key");
        }
        return st.containsKey(key);
    }

    /** Returns the number of key-value pairs. */
    public int size() {
        return st.size();
    }

    /** Returns whether the table is empty. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Returns all keys in ascending order.
     *
     * <p>The result is a read-only live view. Callers cannot remove keys through it and bypass
     * {@link #delete}.
     */
    public Iterable<Key> keys() {
        return java.util.Collections.unmodifiableSet(st.keySet());
    }

    /**
     * Returns a read-only iterator over all keys in ascending order.
     *
     * @deprecated Replaced by {@link #keys()}.
     */
    @Deprecated
    @Override
    public java.util.Iterator<Key> iterator() {
        return keys().iterator();
    }

    /**
     * Returns the smallest key.
     *
     * @throws java.util.NoSuchElementException if this symbol table is empty
     */
    public Key min() {
        if (isEmpty()) {
            throw new java.util.NoSuchElementException("called min() with empty symbol table");
        }
        return st.firstKey();
    }

    /**
     * Returns the largest key.
     *
     * @throws java.util.NoSuchElementException if this symbol table is empty
     */
    public Key max() {
        if (isEmpty()) {
            throw new java.util.NoSuchElementException("called max() with empty symbol table");
        }
        return st.lastKey();
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}.
     *
     * @throws NullPointerException if {@code key} is null
     * @throws java.util.NoSuchElementException if every key is less than {@code key}
     */
    public Key ceiling(Key key) {
        if (key == null) {
            throw new NullPointerException("called ceiling() with null key");
        }
        Key k = st.ceilingKey(key);
        if (k == null) {
            throw new java.util.NoSuchElementException("all keys are less than " + key);
        }
        return k;
    }

    /**
     * Returns the largest key less than or equal to {@code key}.
     *
     * @throws NullPointerException if {@code key} is null
     * @throws java.util.NoSuchElementException if every key is greater than {@code key}
     */
    public Key floor(Key key) {
        if (key == null) {
            throw new NullPointerException("called floor() with null key");
        }
        Key k = st.floorKey(key);
        if (k == null) {
            throw new java.util.NoSuchElementException("all keys are greater than " + key);
        }
        return k;
    }
}